# PSGD Kron
For the original PSGD repo, see [psgd_torch](https://github.com/lixilinx/psgd_torch).

For the JAX version, see [psgd_jax](https://github.com/evanatyourservice/psgd_jax).

Implementation of the [PSGD Kron optimizer](https://github.com/lixilinx/psgd_torch) in PyTorch.
PSGD is a second-order optimizer originally created by Xi-Lin Li. It uses either a Hessian-based
or a whitening-based (gg^T) preconditioner, together with Lie groups, to improve training convergence,
generalization, and efficiency. I highly suggest reading the README of Xi-Lin's PSGD repo, linked
above, for details on how PSGD works and for experiments using PSGD.
### `kron`:

The most versatile and easy-to-use PSGD optimizer is `kron`, which uses a
Kronecker-factored preconditioner. It has fewer hyperparameters that need tuning than Adam and can
be used as a drop-in replacement for Adam. Each dimension's preconditioner is kept either triangular
or diagonal, depending on `max_size_triangular` and `max_skew_triangular`. For example, for a layer
with shape (256, 128, 64), triangular preconditioners would have shapes (256, 256), (128, 128), and
(64, 64), while diagonal preconditioners would have shapes (256,), (128,), and (64,). Depending on how
these two settings are chosen, `kron` can trade off memory and speed against performance (see below).
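
To make "Kronecker-factored" concrete, here is a minimal NumPy sketch for a 2-D gradient, assuming one triangular factor for dim 0 and one diagonal factor for dim 1. The factors `Q1` and `q2` are random stand-ins, not preconditioners fitted by PSGD, and `precondition` is a hypothetical helper, not kron_torch's code. It checks that applying the small per-dimension factors gives the same result as multiplying the column-major vec(G) by the full Kronecker product P = P2 ⊗ P1, with each factor written as P_i = Q_i^T Q_i.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 4, 3
G = rng.standard_normal((m, n))  # stand-in gradient for an (m, n) weight

# Hypothetical per-dimension factors (random, not fitted by PSGD):
# an upper-triangular Q1 for dim 0 and a diagonal Q2 stored as the vector q2 for dim 1.
Q1 = np.triu(rng.standard_normal((m, m))) + m * np.eye(m)
q2 = rng.uniform(0.5, 2.0, size=n)


def precondition(G, Q1, q2):
    """Apply P = P2 ⊗ P1 to G, where P1 = Q1^T Q1 and P2 = diag(q2**2), without forming P."""
    P1 = Q1.T @ Q1               # (m, m) factor for dim 0
    return (P1 @ G) * (q2 ** 2)  # scaling columns == right-multiplying by diag(q2**2)


out = precondition(G, Q1, q2)

# Reference: build the full (m*n, m*n) matrix and apply it to the column-major vec(G).
P_full = np.kron(np.diag(q2 ** 2), Q1.T @ Q1)
out_full = (P_full @ G.flatten(order="F")).reshape((m, n), order="F")
print(np.allclose(out, out_full))  # True
```

The factored form stores m*m + n numbers here instead of the (m*n)^2 entries of `P_full`, which is what makes per-dimension preconditioners affordable for real layer sizes.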
## Installation
```bash
pip install kron-torch
```
## Basic Usage (Kron)
For basic usage, use the `Kron` optimizer like any other PyTorch optimizer:
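```python
import torch
from kron_torch import Kron

model = torch.nn.Linear(10, 1)  # any nn.Module works
optimizer = Kron(model.parameters())

# One standard PyTorch training step on dummy data
loss = model(torch.randn(32, 10)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```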

**Basic hyperparameters:**

TL;DR: The learning rate acts similarly to Adam's but can be set a little higher, e.g. 0.001 ->
0.003. Weight decay should be set lower than Adam's, e.g. 0.1 -> 0.03 or 0.01. There is no
b2 or epsilon.

`learning_rate`: Kron's learning rate acts similarly to Adam's but can withstand higher
values. Try setting it 3x higher: if 0.001 was best for Adam, try 0.003 for Kron.

`weight_decay`: PSGD relies less on weight decay for generalization than Adam does, and too
high a weight decay can hurt performance. Try setting it 3-10x lower: if the best weight decay for
Adam was 0.1, try 0.03 or 0.01 for Kron.
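
The learning-rate and weight-decay rules of thumb above are simple arithmetic on your tuned Adam values. The helper `kron_starting_point` below is hypothetical (not part of kron_torch); it only computes candidate values to try and does not construct the optimizer.

```python
def kron_starting_point(adam_lr: float, adam_weight_decay: float) -> dict:
    """Translate tuned Adam values into Kron starting points (hypothetical helper).

    Rules of thumb: about 3x Adam's learning rate, and a weight decay
    3-10x lower, returned as candidates to try (largest first).
    """
    return {
        "learning_rate": 3.0 * adam_lr,
        "weight_decay_candidates": [adam_weight_decay / 3.0, adam_weight_decay / 10.0],
    }


start = kron_starting_point(adam_lr=0.001, adam_weight_decay=0.1)
print(start["learning_rate"], start["weight_decay_candidates"])
```

For Adam values of 0.001 and 0.1 this yields a learning rate of about 0.003 and weight-decay candidates of about 0.033 and 0.01, matching the examples in the text.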

`max_size_triangular`: Any dimension larger than this value gets a diagonal preconditioner, and any
smaller dimension gets a triangular one. So if you have a dimension of size 16,384 that you want to
use a diagonal preconditioner, set `max_size_triangular` below that size, for example to 15,000. The
default is 8192.
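
The size threshold can be sketched as a one-line rule. `dim_kinds_by_size` is a hypothetical helper, not kron_torch's code, and it ignores the skew rule described next. It also assumes that a dimension exactly equal to the threshold stays triangular, since the text only covers sizes above and below it.

```python
def dim_kinds_by_size(shape, max_size_triangular=8192):
    """Label each dimension 'tri' or 'diag' using only the size threshold.

    Hypothetical helper. Assumption: a dimension exactly equal to the
    threshold stays triangular (the text only covers above/below).
    """
    return ["diag" if size > max_size_triangular else "tri" for size in shape]


print(dim_kinds_by_size((16384, 1024), max_size_triangular=15000))  # ['diag', 'tri']
print(dim_kinds_by_size((16384, 1024), max_size_triangular=20000))  # ['tri', 'tri']
```

For scale: a triangular preconditioner for the 16,384-wide dimension would be a (16384, 16384) matrix with 268,435,456 entries, versus 16,384 entries for a diagonal one.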

`max_skew_triangular`: For any tensor whose skew is above this value, the larger dimension gets a
diagonal preconditioner. For example, with the default `max_skew_triangular` of 10, a bias of shape
(256,) would be diagonal because 256/1 > 10, and an embedding of shape (50000, 768) would be
(diag, tri) because 50000/768 ≈ 65 > 10. The default of 10 usually makes layers like biases, scales,
and vocabulary embeddings use diagonal preconditioners, with the rest using triangular ones.
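
The skew rule can be sketched as follows. `dim_kinds_by_skew` is a hypothetical helper, not kron_torch's code, and it ignores `max_size_triangular`. It assumes that skew means the largest dimension divided by the product of the other dimensions, which reproduces both examples in the text (256/1 for a bias, 50000/768 for an embedding), and that only the largest dimension is switched to diagonal. See kron.py for the actual logic.

```python
import math


def dim_kinds_by_skew(shape, max_skew_triangular=10.0):
    """Label dims 'tri'/'diag' using only the skew rule (hypothetical helper).

    Assumption: skew = largest dim / product of the other dims, and only the
    largest dim becomes diagonal when skew exceeds the threshold.
    """
    largest = max(range(len(shape)), key=lambda i: shape[i])  # first index of the max
    others = math.prod(s for i, s in enumerate(shape) if i != largest)  # 1 for a 1-D tensor
    skew = shape[largest] / others
    kinds = ["tri"] * len(shape)
    if skew > max_skew_triangular:
        kinds[largest] = "diag"
    return kinds


print(dim_kinds_by_skew((256,)))                             # ['diag']
print(dim_kinds_by_skew((50000, 768)))                       # ['diag', 'tri']
print(dim_kinds_by_skew((256, 128, 64)))                     # ['tri', 'tri', 'tri']
print(dim_kinds_by_skew((768, 768), max_skew_triangular=0))  # ['diag', 'tri']
```

Under this reading, a threshold of 0 switches the largest dimension of every tensor to diagonal, since every skew is positive. This matches the (diag, tri) behavior described in the next paragraph.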

Interesting note: Setting `max_skew_triangular` to 0 gives all layers (diag, tri) preconditioners,
which use slightly less memory than Adam (and run slightly faster). Setting `max_size_triangular` to 0
gives all layers diagonal preconditioners, which use the least memory and run the fastest, but
training may converge more slowly.
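
The memory side of this trade-off can be estimated by counting preconditioner entries for one layer under the three configurations. The sketch below uses a hypothetical (4096, 1024) layer and a hypothetical helper `preconditioner_entries`. It counts a triangular dimension as a full (d, d) matrix, following the shapes given earlier. It counts only preconditioner entries and ignores momentum and any other optimizer state, so it is not a complete memory comparison with Adam.

```python
def preconditioner_entries(shape, kinds):
    """Count stored preconditioner entries: d*d for a 'tri' dim, d for a 'diag' dim."""
    return sum(d * d if kind == "tri" else d for d, kind in zip(shape, kinds))


shape = (4096, 1024)  # hypothetical linear layer with 4,194,304 parameters
configs = {
    "default (tri, tri)": ["tri", "tri"],
    "max_skew_triangular=0 (diag, tri)": ["diag", "tri"],
    "max_size_triangular=0 (diag, diag)": ["diag", "diag"],
}
for name, kinds in configs.items():
    print(f"{name}: {preconditioner_entries(shape, kinds):,}")
```

The counts are 17,825,792, 1,052,672, and 5,120 respectively. Making the larger dimension diagonal removes its d*d term, which dominates the total.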

`preconditioner_update_probability`: By default, the preconditioner update probability follows a
schedule that works well in most cases. It anneals from 1 to 0.03 early in training, so steps are
slightly slower at the start but speed up to near Adam's speed by around 3k steps.

See kron.py for more hyperparameter details.
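
The sketch below shows one plausible annealing schedule for the preconditioner update probability: hold at 1, then decay exponentially toward a floor of 0.03, with a random draw deciding whether to update the preconditioner on each step. The function names and the `flat_start` and `decay` values are hypothetical illustrations, not kron.py's defaults.

```python
import math
import random


def update_probability(step, max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
    """Illustrative schedule: hold at max_prob for `flat_start` steps, then decay
    exponentially, never dropping below min_prob. Values are assumptions."""
    prob = max_prob * math.exp(-decay * max(step - flat_start, 0))
    return min(max(prob, min_prob), max_prob)


def should_update_preconditioner(step, rng=random):
    """Update the preconditioner on this step with the scheduled probability."""
    return rng.random() < update_probability(step)


# Expected number of preconditioner updates over the first 3,000 steps
expected = sum(update_probability(s) for s in range(3000))
print(round(expected))
```

With these illustrative values, the preconditioner is updated on every step for the first 500 steps and on roughly 8% of steps by step 3,000 (exp(-2.5) ≈ 0.08). Most of the per-step cost then comes from applying the existing preconditioner rather than from refitting it.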
Raw data
{
"_id": null,
"home_page": null,
"name": "kron-torch",
"maintainer": null,
"docs_url": null,
"requires_python": ">=3.9",
"maintainer_email": null,
"keywords": "python, machine learning, optimization, pytorch",
"author": "Evan Walters, Omead Pooladzandi, Xi-Lin Li",
"author_email": null,
"download_url": "https://files.pythonhosted.org/packages/85/3e/9279765cd7818d1a01594628089a7f61a5fd354d5fd9140c1eaf7361b1a9/kron_torch-0.1.3.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": null,
"summary": "An implementation of PSGD Kron optimizer in PyTorch.",
"version": "0.1.3",
"project_urls": {
"homepage": "https://github.com/evanatyourservice/kron_torch",
"repository": "https://github.com/evanatyourservice/kron_torch"
},
"split_keywords": [
"python",
" machine learning",
" optimization",
" pytorch"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "8142eefb75f0e144fcb7d912aa6b7e5e694e5e2fc0f109be4a7c3495c28487d8",
"md5": "af376108d6666226cc826f14b4325dbc",
"sha256": "7f152100834398c4c6377c34fb3ddc2b6b6e405626d852afbbd47a1a911c5d3e"
},
"downloads": -1,
"filename": "kron_torch-0.1.3-py3-none-any.whl",
"has_sig": false,
"md5_digest": "af376108d6666226cc826f14b4325dbc",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.9",
"size": 15008,
"upload_time": "2024-10-01T19:32:35",
"upload_time_iso_8601": "2024-10-01T19:32:35.139880Z",
"url": "https://files.pythonhosted.org/packages/81/42/eefb75f0e144fcb7d912aa6b7e5e694e5e2fc0f109be4a7c3495c28487d8/kron_torch-0.1.3-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "853e9279765cd7818d1a01594628089a7f61a5fd354d5fd9140c1eaf7361b1a9",
"md5": "f75f17b63dc67feb921089be9e829dfc",
"sha256": "805cbef20153e757bad9ee5a832b77b972c83443e5e7cc3cb3ebe2a307451adf"
},
"downloads": -1,
"filename": "kron_torch-0.1.3.tar.gz",
"has_sig": false,
"md5_digest": "f75f17b63dc67feb921089be9e829dfc",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.9",
"size": 15550,
"upload_time": "2024-10-01T19:32:37",
"upload_time_iso_8601": "2024-10-01T19:32:37.414009Z",
"url": "https://files.pythonhosted.org/packages/85/3e/9279765cd7818d1a01594628089a7f61a5fd354d5fd9140c1eaf7361b1a9/kron_torch-0.1.3.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2024-10-01 19:32:37",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "evanatyourservice",
"github_project": "kron_torch",
"travis_ci": false,
"coveralls": false,
"github_actions": false,
"lcname": "kron-torch"
}