# PSGD (Preconditioned Stochastic Gradient Descent)
Implementation of [PSGD optimizer](https://github.com/lixilinx/psgd_torch) in JAX (optax-style).
PSGD is a second-order optimizer originally created by Xi-Lin Li that uses a Hessian-based
preconditioner and Lie groups to improve convergence, generalization, and efficiency.
## Installation
```bash
pip install psgd-jax
```
## Usage
PSGD defaults to a gradient-whitening preconditioner (gg^T). In this case, you can use PSGD
like any other optax optimizer:
```python
import jax
import jax.numpy as jnp
import optax
from psgd_jax.xmat import xmat  # or low_rank_approximation, affine


def loss_fn(params, x):
    return jnp.sum((params - x) ** 2)


params = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([0.0, 0.0, 0.0])

# make optimizer and init state
opt = xmat(
    learning_rate=1.0,
    b1=0.0,
    preconditioner_update_probability=1.0,  # preconditioner update frequency
)
opt_state = opt.init(params)


def step(params, x, opt_state):
    loss_val, grad = jax.value_and_grad(loss_fn)(params, x)
    updates, opt_state = opt.update(grad, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val


while True:
    params, opt_state, loss_val = step(params, x, opt_state)
    print(loss_val)
    if loss_val < 1e-4:
        print("yay")
        break
# Expected output:
# 14.0
# 5.1563816
# 1.7376599
# 0.6118454
# 0.18457186
# 0.056664664
# 0.014270116
# 0.0027846962
# 0.00018843572
# 4.3836744e-06
# yay
```
However, PSGD is best used with a Hessian-vector product. If values are provided for PSGD's extra
update-function arguments `Hvp`, `vector`, and `update_preconditioner`, PSGD automatically
switches to Hessian-based preconditioning. `Hvp` is the Hessian-vector product, `vector` is the
random vector used to calculate it, and `update_preconditioner` is a boolean telling PSGD
whether the preconditioner is being updated this step (real hvp and vector passed in) or not
(dummy hvp and vector passed in).
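For intuition, a Hessian-vector product can be computed in plain JAX with forward-over-reverse differentiation. This is a generic sketch of the technique, not necessarily the exact internals of `hessian_helper`:

```python
import jax
import jax.numpy as jnp


def loss_fn(params):
    return jnp.sum(params ** 2)


params = jnp.array([1.0, 2.0, 3.0])
key = jax.random.PRNGKey(0)
# random probe vector with the same structure as params
vector = jax.random.normal(key, params.shape)

# Hvp is the directional derivative of the gradient along `vector`,
# computed as a jvp through the grad function (forward-over-reverse).
grad_fn = jax.grad(loss_fn)
_, hvp = jax.jvp(grad_fn, (params,), (vector,))
```

Here the Hessian of `loss_fn` is `2I`, so `hvp` equals `2 * vector`, which makes this easy to sanity-check.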
The `hessian_helper` function can help with this and generally replace `jax.value_and_grad`:
```python
import jax
import jax.numpy as jnp
import optax
from psgd_jax.xmat import xmat  # or low_rank_approximation, affine
from psgd_jax import hessian_helper


def loss_fn(params, x):
    return jnp.sum((params - x) ** 2)


params = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([0.0, 0.0, 0.0])

# make optimizer and init state
# no need to set 'preconditioner_update_probability' here, it's handled by hessian_helper
opt = xmat(
    learning_rate=1.0,
    b1=0.0,
)
opt_state = opt.init(params)


def step(key, params, x, opt_state):
    # replace jax.value_and_grad with the hessian_helper:
    key, subkey = jax.random.split(key)
    loss_fn_out, grad, hvp, vector, update_precond = hessian_helper(
        subkey,
        loss_fn,
        params,
        loss_fn_extra_args=(x,),
        has_aux=False,
        preconditioner_update_probability=1.0,  # update frequency handled in hessian_helper
    )
    loss_val = loss_fn_out

    # Pass hvp, random vector, and whether we're updating the preconditioner
    # this step into the update function. PSGD will automatically switch to
    # hessian-based preconditioning when these are provided.
    updates, opt_state = opt.update(
        grad,
        opt_state,
        Hvp=hvp,
        vector=vector,
        update_preconditioner=update_precond,
    )

    params = optax.apply_updates(params, updates)
    return key, params, opt_state, loss_val


key = jax.random.PRNGKey(0)
while True:
    key, params, opt_state, loss_val = step(key, params, x, opt_state)
    print(loss_val)
    if loss_val < 1e-4:
        print("yay")
        break
# Expected output:
# 14.0
# 7.460699e-14
# yay
```
If `preconditioner_update_probability` is lowered, time is saved by computing the Hessian less
often, but convergence may be slower.
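Conceptually, a lowered update probability amounts to a coin flip each step: with probability p a real hvp is computed, otherwise dummies are passed so the step stays cheap. This is a hypothetical sketch of that gating logic (`hessian_helper` handles it for you):

```python
import jax


def maybe_update_preconditioner(key, p):
    # Bernoulli draw: True on roughly p fraction of steps.
    return jax.random.bernoulli(key, p)


key = jax.random.PRNGKey(0)
# e.g. p=0.1 updates the preconditioner on ~10% of steps
update_this_step = maybe_update_preconditioner(key, p=0.1)
```

Because the decision is drawn inside the step, it stays compatible with `jax.jit`.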
## PSGD variants
`psgd_jax.xmat`, `psgd_jax.low_rank_approximation`, `psgd_jax.affine`
There are three variants of PSGD: XMat, which uses an x-shaped global preconditioner; LRA, which
uses a low-rank-approximation global preconditioner; and Affine, which uses block-diagonal or
diagonal preconditioners.
**XMat:**
XMat is very simple to use, uses global Hessian information for its preconditioner, and has
memory use of only n_params * 3 (including momentum, which is optional; set b1 to 0 to disable).
**LRA:**
Low-rank approximation uses a low-rank Hessian for its preconditioner and can give very strong
results. Its memory use is n_params * (2 * rank + 1), or n_params * (2 * rank) without momentum.
**Affine:**
Affine does not use global Hessian information, but can be powerful nonetheless and can use
less memory than XMat or LRA. `max_size_triangular` and `max_skew_triangular` determine whether
each dimension's preconditioner is block diagonal or diagonal.
For example, if `max_size_triangular` is set to 512 and a layer's shape is (1024, 16, 64), the
preconditioner shapes will be [diag, block_diag, block_diag], i.e. [(1024,), (16, 16), (64, 64)],
because 1024 > 512.
If `max_skew_triangular` is set to 32 and a layer's shape is (1024, 3),
the preconditioner shapes will be [diag, block_diag], i.e. [(1024,), (3, 3)], because 1024/3 is
greater than 32.
If `max_size_triangular` and `max_skew_triangular` are set to 0, the affine preconditioners
will be entirely diagonal and will use less memory than Adam, even with momentum enabled.
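The shape rules above can be sketched as a small helper. This is a hypothetical illustration of the rules as described, not the library's actual internals (which may differ in details such as tie-breaking):

```python
import math


def affine_precond_shapes(layer_shape, max_size_triangular, max_skew_triangular):
    """Illustrative only: per-dimension preconditioner shapes for Affine PSGD.

    A dimension d gets a diagonal preconditioner (d,) if d exceeds
    max_size_triangular, or if d divided by the product of the other
    dimensions exceeds max_skew_triangular; otherwise it gets a full
    (d, d) preconditioner.
    """
    total = math.prod(layer_shape)
    shapes = []
    for d in layer_shape:
        other = total // d  # product of the remaining dimensions
        too_big = d > max_size_triangular
        too_skewed = other > 0 and d / other > max_skew_triangular
        shapes.append((d,) if too_big or too_skewed else (d, d))
    return shapes
```

With `max_size_triangular=512` this reproduces the (1024, 16, 64) example, and with `max_skew_triangular=32` it reproduces the (1024, 3) example above.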
## Notes on sharding
For now PSGD does not explicitly handle sharding, so intermediates are handled naively by
JAX based on how users define input and output shardings. Our goal is to improve preconditioner shapes
and explicitly handle sharding for PSGD, especially for XMat and LRA, to make it more efficient
in distributed settings.
**Optimizer state shapes:**
Momentum is always the same shape as params.
Affine might be the most out-of-the-box sharding-friendly variant, as it uses block-diagonal or diagonal
preconditioners. For example, if a layer has shape (1024, 16, 64) and `max_size_triangular` is set
to 512, the preconditioner shapes will be `[(1024,), (16, 16), (64, 64)]`, which could be sharded as
the user sees fit.
XMat's preconditioners `a` and `b` are both of shape `(n_params,)`. If n_params is odd, or not divisible
by the number of devices, dummy params could be added before optimizer init and update.
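Such padding might look like the following sketch, which pads a flattened parameter vector with zeros so its length divides evenly across devices. This is hypothetical pre-processing by the user, not something `psgd_jax` provides:

```python
import jax.numpy as jnp


def pad_to_multiple(flat_params, n_devices):
    # Pad with zero "dummy params" so len(flat_params) % n_devices == 0.
    n = flat_params.shape[0]
    remainder = n % n_devices
    if remainder:
        pad = jnp.zeros(n_devices - remainder, dtype=flat_params.dtype)
        flat_params = jnp.concatenate([flat_params, pad])
    return flat_params


padded = pad_to_multiple(jnp.ones(10), 4)  # length 12, divisible by 4
```

The original parameter count would need to be remembered so the dummy entries can be sliced off after each update.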
LRA's preconditioner shapes are `U=(n_params, rank)`, `V=(n_params, rank)`, and `d=(n_params, 1)`.
## Resources
PSGD papers and resources, listed from Xi-Lin's repo:
1) Xi-Lin Li. Preconditioned stochastic gradient descent, [arXiv:1512.04202](https://arxiv.org/abs/1512.04202), 2015. (General ideas of PSGD, preconditioner fitting losses and Kronecker product preconditioners.)
2) Xi-Lin Li. Preconditioner on matrix Lie group for SGD, [arXiv:1809.10232](https://arxiv.org/abs/1809.10232), 2018. (Focus on preconditioners with the affine Lie group.)
3) Xi-Lin Li. Black box Lie group preconditioners for SGD, [arXiv:2211.04422](https://arxiv.org/abs/2211.04422), 2022. (Mainly about the LRA preconditioner. See [these supplementary materials](https://drive.google.com/file/d/1CTNx1q67_py87jn-0OI-vSLcsM1K7VsM/view) for detailed math derivations.)
4) Xi-Lin Li. Stochastic Hessian fittings on Lie groups, [arXiv:2402.11858](https://arxiv.org/abs/2402.11858), 2024. (Some theoretical works on the efficiency of PSGD. The Hessian fitting problem is shown to be strongly convex on set ${\rm GL}(n, \mathbb{R})/R_{\rm polar}$.)
5) Omead Pooladzandi, Xi-Lin Li. Curvature-informed SGD via general purpose Lie-group preconditioners, [arXiv:2402.04553](https://arxiv.org/abs/2402.04553), 2024. (Plenty of benchmark results and analyses for PSGD vs. other optimizers.)
## License
[![CC BY 4.0][cc-by-image]][cc-by]
This work is licensed under a [Creative Commons Attribution 4.0 International License][cc-by].
2024 Evan Walters, Omead Pooladzandi, Xi-Lin Li
[cc-by]: http://creativecommons.org/licenses/by/4.0/
[cc-by-image]: https://licensebuttons.net/l/by/4.0/88x31.png
[cc-by-shield]: https://img.shields.io/badge/License-CC%20BY%204.0-lightgrey.svg