psgd-jax

- Name: psgd-jax
- Version: 0.1.5
- Summary: An implementation of the PSGD optimizer in JAX.
- Author: Evan Walters, Omead Pooladzandi, Xi-Lin Li
- Homepage / repository: https://github.com/evanatyourservice/psgd_jax
- Requires Python: >=3.9
- Keywords: python, machine learning, optimization, jax
- Uploaded: 2024-08-22 16:15:27

# PSGD (Preconditioned Stochastic Gradient Descent)

Implementation of the [PSGD optimizer](https://github.com/lixilinx/psgd_torch) in JAX (optax-style).
PSGD is a second-order optimizer originally created by Xi-Lin Li that uses a hessian-based
preconditioner and Lie groups to improve convergence, generalization, and efficiency.


## Installation

```bash
pip install psgd-jax
```

## Usage

PSGD defaults to a gradient-whitening preconditioner (based on gg^T). In this case, you can use PSGD 
like any other optax optimizer:

```python
import jax
import jax.numpy as jnp
import optax
from psgd_jax.xmat import xmat  # or low_rank_approximation, affine


def loss_fn(params, x):
    return jnp.sum((params - x) ** 2)


params = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([0.0, 0.0, 0.0])

# make optimizer and init state
opt = xmat(
    learning_rate=1.0,
    b1=0.0,
    preconditioner_update_probability=1.0,  # probability of updating the preconditioner each step
)
opt_state = opt.init(params)


def step(params, x, opt_state):
    loss_val, grad = jax.value_and_grad(loss_fn)(params, x)
    updates, opt_state = opt.update(grad, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val


while True:
    params, opt_state, loss_val = step(params, x, opt_state)
    print(loss_val)
    if loss_val < 1e-4:
        print("yay")
        break

# Expected output:
# 14.0
# 5.1563816
# 1.7376599
# 0.6118454
# 0.18457186
# 0.056664664
# 0.014270116
# 0.0027846962
# 0.00018843572
# 4.3836744e-06
# yay
```

However, PSGD is best used with a hessian vector product. If values are provided for PSGD's extra 
update-function arguments `Hvp`, `vector`, and `update_preconditioner`, PSGD automatically switches 
to hessian-based preconditioning. `Hvp` is the hessian vector product, `vector` is the random 
vector used to calculate it, and `update_preconditioner` is a boolean telling PSGD whether the 
preconditioner is being updated this step (a real hvp and vector are passed in) or not (a dummy 
hvp and vector are passed in).

The `hessian_helper` function can help with this and generally replace `jax.value_and_grad`:

```python
import jax
import jax.numpy as jnp
import optax
from psgd_jax.xmat import xmat  # or low_rank_approximation, affine
from psgd_jax import hessian_helper


def loss_fn(params, x):
    return jnp.sum((params - x) ** 2)


params = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([0.0, 0.0, 0.0])

# make optimizer and init state
# no need to set 'preconditioner_update_probability' here, it's handled by hessian_helper
opt = xmat(
    learning_rate=1.0,
    b1=0.0,
)
opt_state = opt.init(params)


def step(key, params, x, opt_state):
    # replace jax.value_and_grad with the hessian_helper:
    key, subkey = jax.random.split(key)
    loss_fn_out, grad, hvp, vector, update_precond = hessian_helper(
        subkey,
        loss_fn,
        params,
        loss_fn_extra_args=(x,),
        has_aux=False,
        preconditioner_update_probability=1.0,  # update probability handled by hessian_helper instead of the optimizer
    )
    loss_val = loss_fn_out

    # Pass hvp, random vector, and whether we're updating the preconditioner 
    # this step into the update function. PSGD will automatically switch to 
    # hessian-based preconditioning when these are provided.
    updates, opt_state = opt.update(
        grad,
        opt_state,
        Hvp=hvp,
        vector=vector,
        update_preconditioner=update_precond
    )

    params = optax.apply_updates(params, updates)
    return key, params, opt_state, loss_val


key = jax.random.PRNGKey(0)
while True:
    key, params, opt_state, loss_val = step(key, params, x, opt_state)
    print(loss_val)
    if loss_val < 1e-4:
        print("yay")
        break

# Expected output:
# 14.0
# 7.460699e-14
# yay
```

Lowering `preconditioner_update_probability` saves time by calculating the hessian less often, 
but convergence could be slower.
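
For example, reusing the names from the `hessian_helper` example above, the hessian would only be computed on roughly 10% of steps (the 0.1 here is just an illustrative value):

```python
# Same call as in the hessian_helper example above, but the preconditioner
# is updated on roughly 1 in 10 steps instead of every step.
loss_val, grad, hvp, vector, update_precond = hessian_helper(
    subkey,
    loss_fn,
    params,
    loss_fn_extra_args=(x,),
    has_aux=False,
    preconditioner_update_probability=0.1,
)
```

(In the gradient-whitening case, the same argument is set on the optimizer itself, as in the first example.)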

## PSGD variants

`psgd_jax.xmat` `psgd_jax.low_rank_approximation` `psgd_jax.affine`

There are three variants of PSGD: XMat, which uses an x-shaped global preconditioner; LRA, which 
uses a low-rank-approximation global preconditioner; and Affine, which uses block-diagonal or 
diagonal preconditioners.
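
As a minimal sketch, all three variants are constructed in the same optax style. The import paths follow the comments in the examples above; `max_size_triangular` and `max_skew_triangular` are described below, and the hyperparameter values here are just illustrative:

```python
from psgd_jax.xmat import xmat
from psgd_jax.low_rank_approximation import low_rank_approximation
from psgd_jax.affine import affine

# One optimizer of each variant; all are optax-style gradient transformations.
xmat_opt = xmat(learning_rate=1e-3)
lra_opt = low_rank_approximation(learning_rate=1e-3)
affine_opt = affine(
    learning_rate=1e-3,
    max_size_triangular=512,  # see the shape examples below
    max_skew_triangular=32,
)
```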

**XMat:**

XMat is very simple to use, uses global hessian information for its preconditioner, and has a 
memory footprint of only n_params * 3 (including momentum, which is optional; set b1 to 0 to disable it).

**LRA:**

Low-rank approximation uses a low-rank approximation of the hessian for its preconditioner and can 
give very strong results. Its memory use is n_params * (2 * rank + 1), or n_params * (2 * rank) without momentum.
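
As a rough back-of-the-envelope check of these formulas, assuming float32 state and a hypothetical 10M-parameter model with rank 10:

```python
# Approximate optimizer state sizes from the formulas above,
# assuming 4-byte (float32) elements.
n_params = 10_000_000
rank = 10
bytes_per_elem = 4

xmat_bytes = n_params * 3 * bytes_per_elem              # ~0.12 GB, momentum included
lra_bytes = n_params * (2 * rank + 1) * bytes_per_elem  # ~0.84 GB, momentum included
print(f"XMat: {xmat_bytes / 1e9:.2f} GB, LRA: {lra_bytes / 1e9:.2f} GB")
```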

**Affine:**

Affine does not use global hessian information, but it can be powerful nonetheless and possibly 
use less memory than XMat or LRA. `max_size_triangular` and `max_skew_triangular` determine 
whether each dimension's preconditioner is block diagonal or diagonal.

For example, if `max_size_triangular` is set to 512 and a layer's shape is (1024, 16, 64), the 
preconditioner shapes will be [diag, block_diag, block_diag], or [(1024,), (16, 16), (64, 64)], 
because 1024 > 512.

If `max_skew_triangular` is set to 32 and a layer's shape is (1024, 3), 
the preconditioner shapes will be [diag, block_diag] or [(1024,), (3, 3)] because 1024/3 is 
greater than 32.

If `max_size_triangular` and `max_skew_triangular` are set to 0, the affine preconditioners
will be entirely diagonal and use less memory than Adam, even with momentum.
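
The per-dimension rule can be sketched as a small helper. This just mirrors the prose above and is not the library's actual implementation:

```python
import math

def affine_precond_shapes(layer_shape, max_size_triangular, max_skew_triangular):
    """Sketch of the per-dimension rule described above (not the actual API)."""
    shapes = []
    for dim in layer_shape:
        rest = math.prod(layer_shape) // dim             # product of the other dims
        too_large = dim > max_size_triangular            # e.g. 1024 > 512
        too_skewed = dim / rest > max_skew_triangular    # e.g. 1024 / 3 > 32
        shapes.append((dim,) if too_large or too_skewed else (dim, dim))
    return shapes

print(affine_precond_shapes((1024, 16, 64), 512, 1_000))  # [(1024,), (16, 16), (64, 64)]
print(affine_precond_shapes((1024, 3), 10_000, 32))       # [(1024,), (3, 3)]
print(affine_precond_shapes((1024, 3), 0, 0))             # [(1024,), (3,)] -- all diagonal
```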


## Notes on sharding

For now, PSGD does not explicitly handle any sharding, so intermediates are handled naively by 
JAX based on how users define in and out shardings. Our goal is to improve preconditioner shapes 
and explicitly handle sharding for PSGD, especially for XMat and LRA, to make it more efficient
in distributed settings.

**Optimizer state shapes:**

Momentum is always the same shape as the params.

Affine might be the most sharding-friendly out of the box, as it uses block-diagonal or diagonal 
preconditioners. For example, if a layer has shape (1024, 16, 64) and `max_size_triangular` is set 
to 512, the preconditioner shapes will be `[(1024,), (16, 16), (64, 64)]`, which could be sharded as 
the user sees fit.

XMat's preconditioners `a` and `b` are both of shape `(n_params,)`. If n_params is odd or not divisible 
by the number of devices, dummy params could be added before optimizer init and update.
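
For example, a hypothetical padding helper (not part of psgd-jax) could look like this:

```python
import jax.numpy as jnp

def pad_to_multiple(flat_params, n_devices):
    """Append dummy (zero) params so the flat size divides evenly across devices."""
    remainder = flat_params.size % n_devices
    if remainder == 0:
        return flat_params
    padding = jnp.zeros(n_devices - remainder, dtype=flat_params.dtype)
    return jnp.concatenate([flat_params, padding])

padded = pad_to_multiple(jnp.ones(1027), 8)  # 1027 -> 1032, divisible by 8
```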

LRA's preconditioner shapes are `U=(n_params, rank)`, `V=(n_params, rank)`, and `d=(n_params, 1)`.


## Resources

PSGD papers and resources, as listed in Xi-Lin's repo:

1) Xi-Lin Li. Preconditioned stochastic gradient descent, [arXiv:1512.04202](https://arxiv.org/abs/1512.04202), 2015. (General ideas of PSGD, preconditioner fitting losses and Kronecker product preconditioners.)
2) Xi-Lin Li. Preconditioner on matrix Lie group for SGD, [arXiv:1809.10232](https://arxiv.org/abs/1809.10232), 2018. (Focus on preconditioners with the affine Lie group.)
3) Xi-Lin Li. Black box Lie group preconditioners for SGD, [arXiv:2211.04422](https://arxiv.org/abs/2211.04422), 2022. (Mainly about the LRA preconditioner. See [these supplementary materials](https://drive.google.com/file/d/1CTNx1q67_py87jn-0OI-vSLcsM1K7VsM/view) for detailed math derivations.)
4) Xi-Lin Li. Stochastic Hessian fittings on Lie groups, [arXiv:2402.11858](https://arxiv.org/abs/2402.11858), 2024. (Some theoretical works on the efficiency of PSGD. The Hessian fitting problem is shown to be strongly convex on set ${\rm GL}(n, \mathbb{R})/R_{\rm polar}$.)
5) Omead Pooladzandi, Xi-Lin Li. Curvature-informed SGD via general purpose Lie-group preconditioners, [arXiv:2402.04553](https://arxiv.org/abs/2402.04553), 2024. (Plenty of benchmark results and analyses for PSGD vs. other optimizers.)


## License

[![CC BY 4.0][cc-by-image]][cc-by]

This work is licensed under a [Creative Commons Attribution 4.0 International License][cc-by].

2024 Evan Walters, Omead Pooladzandi, Xi-Lin Li


[cc-by]: http://creativecommons.org/licenses/by/4.0/
[cc-by-image]: https://licensebuttons.net/l/by/4.0/88x31.png
[cc-by-shield]: https://img.shields.io/badge/License-CC%20BY%204.0-lightgrey.svg


            
