psgd-jax

Name: psgd-jax
Version: 0.2.8
Summary: An implementation of the PSGD optimizer in JAX.
Author: Evan Walters, Omead Pooladzandi, Xi-Lin Li
Requires-Python: >=3.9
Keywords: python, machine learning, optimization, jax
Repository: https://github.com/evanatyourservice/psgd_jax
Upload time: 2024-12-08 20:46:32
# PSGD (Preconditioned Stochastic Gradient Descent)

For original PSGD repo, see [psgd_torch](https://github.com/lixilinx/psgd_torch).

For PyTorch Kron version, see [kron_torch](https://github.com/evanatyourservice/kron_torch).

Implementations of [PSGD optimizers](https://github.com/lixilinx/psgd_torch) in JAX (optax-style). 
PSGD is a second-order optimizer, originally created by Xi-Lin Li, that uses either a Hessian-based 
or whitening-based (gg^T) preconditioner and Lie groups to improve training convergence, 
generalization, and efficiency. I highly suggest reading the readme of Xi-Lin's PSGD repo linked 
above for details on how PSGD works and experiments using PSGD. Paper resources are also 
listed near the bottom of this readme.

### `kron`:

The most versatile and easy-to-use PSGD optimizer is `kron`, which uses a Kronecker-factored 
preconditioner. It has fewer hyperparameters that need tuning than Adam and can generally act as a 
drop-in replacement.

## Installation

```bash
pip install psgd-jax
```

## Basic Usage (Kron)

By default, Kron schedules the preconditioner update probability to start at 1.0 and anneal to 0.03 
over the beginning of training, so training will be slightly slower at the start but will speed up 
by around 4k steps.

For basic usage, use the `kron` optimizer like any other optax optimizer:

```python
import optax
from psgd_jax.kron import kron

optimizer = kron()
opt_state = optimizer.init(params)

# grads come from e.g. jax.grad of your loss with respect to params
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```

**Basic hyperparameters:**

TLDR: Learning rate and weight decay act similarly to Adam's; start with Adam-like settings and go 
from there, perhaps with a slightly lower learning rate (e.g. half). There is no b2 or epsilon.

The next three settings control whether a dimension's preconditioner is diagonal or triangular. 
For example, for a layer with shape (256, 128), triangular preconditioners would have shapes (256, 256) 
and (128, 128), while diagonal preconditioners would have shapes (256,) and (128,). Depending on how 
these settings are chosen, `kron` can trade off memory and speed against effectiveness. The defaults 
make most preconditioners triangular, except for 1-dimensional layers and very large dimensions.

`max_size_triangular`: Any dimension with size above this value will have a diagonal preconditioner.

`min_ndim_triangular`: Any tensor with fewer than this number of dims will have all diagonal 
preconditioners. Default is 2, so single-dim layers like bias and scale will use diagonal 
preconditioners.

`memory_save_mode`: Can be None, 'one_diag', or 'all_diag'. None is default and lets all 
preconditioners be triangular. 'one_diag' sets the largest or last dim per layer as diagonal 
using `np.argsort(shape)[::-1][0]`. 'all_diag' sets all preconditioners to be diagonal.
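
For example, a minimal sketch of dialing memory down with these settings (the values here are 
illustrative, not recommendations from the source):

```python
from psgd_jax.kron import kron

# Illustrative values only: dims larger than 8192 get diagonal preconditioners,
# tensors with fewer than 2 dims are all-diagonal, and 'one_diag' additionally
# makes the largest dim of each layer diagonal to save memory.
optimizer = kron(
    max_size_triangular=8192,
    min_ndim_triangular=2,
    memory_save_mode="one_diag",
)
```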

`preconditioner_update_probability`: The preconditioner update probability uses a schedule by default 
that works well for most cases. It anneals from 1.0 to 0.03 over the beginning of training, so training 
will be slightly slower at the start but will speed up by around 4k steps. PSGD generally benefits 
from more preconditioner updates at the start of training, but once the preconditioner is learned 
it is okay to update it less often. An easy way to adjust the update frequency is to define your own 
schedule using the `precond_update_prob_schedule` function in kron.py (changing its `min_prob` value 
is easiest) and pass it to kron through the `preconditioner_update_probability` hyperparameter.
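
For example, a minimal sketch of passing a custom schedule, assuming `precond_update_prob_schedule` 
is importable from `psgd_jax.kron` and accepts a `min_prob` argument as described above:

```python
from psgd_jax.kron import kron, precond_update_prob_schedule

# Keep a higher floor than the default 0.03 (value is illustrative).
custom_schedule = precond_update_prob_schedule(min_prob=0.05)
optimizer = kron(preconditioner_update_probability=custom_schedule)
```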

This is the default schedule defined in the `precond_update_prob_schedule` function at the top of kron.py:

<img src="assets/default_schedule.png" alt="Default Schedule" width="800" style="max-width: 100%; height: auto;" />


**Sharding:**

Kron contains einsums, and in general the first axis of the preconditioner matrices is the 
contracting axis.

If using only FSDP, I usually shard the last axis of each preconditioner matrix and call it good.

However, if using tensor parallelism in addition to FSDP, you may want to think more carefully about how 
the preconditioners are sharded in the train state. For example, with grads of shape (256, 128) and kron 
preconditioners of shapes (256, 256) and (128, 128), if the grads are sharded as (fsdp, tensor), 
then you may want to shard the (256, 256) preconditioner as (fsdp, tensor) and the (128, 128) 
preconditioner as (tensor, fsdp) so the grads and their preconditioners have matching contracting axes.
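
As a rough sketch of that layout using `jax.sharding.PartitionSpec` (the mesh axis names 'fsdp' and 
'tensor' are just the ones from the example above):

```python
from jax.sharding import PartitionSpec as P

# Grads of shape (256, 128) sharded as (fsdp, tensor).
grad_spec = P("fsdp", "tensor")

# Kron preconditioners of shapes (256, 256) and (128, 128); sharding them this
# way keeps their contracting axes aligned with the corresponding grad axes.
precond_256_spec = P("fsdp", "tensor")
precond_128_spec = P("tensor", "fsdp")
```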


**Scanned layers:**

If you are scanning layers in your network, you can also have kron scan over these layers while 
updating and applying the preconditioner. Simply pass in a pytree through `scanned_layers` with 
the same structure as your params with bool values indicating which layers are scanned. PSGD will 
vmap over the first dims of those layers. If you need a more advanced scanning setup, please open 
an issue.

For very large models, the preconditioner update may use too much memory all at once when scanning, 
in which case you can set `lax_map_scanned_layers` to `True` and `lax_map_batch_size` to a 
reasonable batch size for your setup (`lax.map` scans over batches of vmap; see the JAX docs). If 
your net has 32 scanned layers and you're hitting OOM during the optimizer step, you can break the scan 
into 2 or 4 batches by setting `lax_map_batch_size` to 16 or 8, respectively.
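
A minimal sketch of both options, using a hypothetical params pytree where 'blocks' is a stack of 
scanned layers (the structure and values are illustrative only):

```python
import jax.numpy as jnp
from psgd_jax.kron import kron

# Hypothetical params: 'blocks' stacks 32 scanned layers along the first dim,
# 'embed' is a regular (unscanned) parameter.
params = {
    "blocks": jnp.zeros((32, 128, 128)),
    "embed": jnp.zeros((1000, 128)),
}

# Same structure as params, with True marking the scanned leaves.
scanned = {"blocks": True, "embed": False}

optimizer = kron(scanned_layers=scanned)

# If the preconditioner update OOMs while scanning, fall back to lax.map over
# smaller batches of layers:
optimizer = kron(
    scanned_layers=scanned,
    lax_map_scanned_layers=True,
    lax_map_batch_size=8,
)
opt_state = optimizer.init(params)
```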


## Advanced Usage (XMat, LRA, Affine)

Other forms of PSGD include XMat, LRA, and Affine. PSGD defaults to a gradient-whitening 
preconditioner (gg^T), in which case you can use PSGD like any other 
optax optimizer:

```python
import jax
import jax.numpy as jnp
import optax
from psgd_jax.xmat import xmat  # or low_rank_approximation, affine


def loss_fn(params, x):
    return jnp.sum((params - x) ** 2)


params = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([0.0, 0.0, 0.0])

# make optimizer and init state
opt = xmat(
    learning_rate=1.0,
    b1=0.0,
    preconditioner_update_probability=1.0,  # preconditioner update frequency
)
opt_state = opt.init(params)


def step(params, x, opt_state):
    loss_val, grad = jax.value_and_grad(loss_fn)(params, x)
    updates, opt_state = opt.update(grad, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val


while True:
    params, opt_state, loss_val = step(params, x, opt_state)
    print(loss_val)
    if loss_val < 1e-4:
        print("yay")
        break

# Expected output:
# 14.0
# 5.1563816
# 1.7376599
# 0.6118454
# 0.18457186
# 0.056664664
# 0.014270116
# 0.0027846962
# 0.00018843572
# 4.3836744e-06
# yay
```

However, PSGD can also be used with a Hessian-vector product. If values are provided for PSGD's extra 
update function arguments `Hvp`, `vector`, and `update_preconditioner`, PSGD automatically 
uses Hessian-based preconditioning. `Hvp` is the Hessian-vector product, `vector` is the random 
vector used to calculate it, and `update_preconditioner` is a boolean that tells PSGD whether the 
preconditioner is being updated this step (real hvp and vector passed in) or not (dummy hvp and 
vector passed in).

The `hessian_helper` function can help with this and generally replace `jax.value_and_grad`:

```python
import jax
import jax.numpy as jnp
import optax
from psgd_jax.xmat import xmat  # or low_rank_approximation, affine
from psgd_jax import hessian_helper


def loss_fn(params, x):
    return jnp.sum((params - x) ** 2)


params = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([0.0, 0.0, 0.0])

# make optimizer and init state
# no need to set 'preconditioner_update_probability' here, it's handled by hessian_helper
opt = xmat(
    learning_rate=1.0,
    b1=0.0,
)
opt_state = opt.init(params)


def step(key, params, x, opt_state):
    # replace jax.value_and_grad with the hessian_helper:
    key, subkey = jax.random.split(key)
    loss_fn_out, grad, hvp, vector, update_precond = hessian_helper(
        subkey,
        loss_fn,
        params,
        loss_fn_extra_args=(x,),
        has_aux=False,
        preconditioner_update_probability=1.0,  # update frequency handled in hessian_helper
    )
    loss_val = loss_fn_out

    # Pass hvp, random vector, and whether we're updating the preconditioner 
    # this step into the update function. PSGD will automatically switch to 
    # hessian-based preconditioning when these are provided.
    updates, opt_state = opt.update(
        grad,
        opt_state,
        Hvp=hvp,
        vector=vector,
        update_preconditioner=update_precond
    )

    params = optax.apply_updates(params, updates)
    return key, params, opt_state, loss_val


key = jax.random.PRNGKey(0)
while True:
    key, params, opt_state, loss_val = step(key, params, x, opt_state)
    print(loss_val)
    if loss_val < 1e-4:
        print("yay")
        break

# Expected output:
# 14.0
# 7.460699e-14
# yay
```

If `preconditioner_update_probability` is lowered, time is saved by calculating the Hessian less 
often, but convergence could be slower.

## PSGD variants

`psgd_jax.kron` - `psgd_jax.xmat` - `psgd_jax.low_rank_approximation` - `psgd_jax.affine`

There are four variants of PSGD: Kron, which uses Kronecker-factored preconditioners for tensors 
of any number of dimensions; XMat, which uses an X-shaped global preconditioner; LRA, which uses 
a low-rank-approximation global preconditioner; and Affine, which uses Kronecker-factored 
preconditioners for matrices.

**Kron:**

Kron uses Kronecker-factored preconditioners for tensors of any number of dimensions. It's very 
versatile, has fewer hyperparameters that need tuning than Adam, and can generally act as a drop-in 
replacement for Adam.

**XMat:**

XMat is very simple to use, uses global Hessian information for its preconditioner, and has 
memory use of only n_params * 3 (including momentum, which is optional; set b1 to 0 to disable it).

**LRA:**

LRA uses a low-rank Hessian approximation for its preconditioner and can give very strong 
results. Its memory use is n_params * (2 * rank + 1), or n_params * 2 * rank without momentum.
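
For a rough sense of scale (numbers are illustrative, not from the source):

```python
n_params = 10_000_000  # hypothetical model size
rank = 10              # hypothetical LRA rank
n_floats = n_params * (2 * rank + 1)  # 210,000,000 floats with momentum
mem_fp32_gb = n_floats * 4 / 1e9      # ~0.84 GB in float32
```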

**Affine:**

Affine does not use global Hessian information, but can be powerful nonetheless and may use 
less memory than XMat, LRA, or Adam. `max_size_triangular` and `max_skew_triangular` determine whether 
a dimension's preconditioner is triangular or diagonal. Affine and Kron are nearly identical for matrices.


## Resources

PSGD papers and resources, as listed in Xi-Lin's repo:

1) Xi-Lin Li. Preconditioned stochastic gradient descent, [arXiv:1512.04202](https://arxiv.org/abs/1512.04202), 2015. (General ideas of PSGD, preconditioner fitting losses and Kronecker product preconditioners.)
2) Xi-Lin Li. Preconditioner on matrix Lie group for SGD, [arXiv:1809.10232](https://arxiv.org/abs/1809.10232), 2018. (Focus on preconditioners with the affine Lie group.)
3) Xi-Lin Li. Black box Lie group preconditioners for SGD, [arXiv:2211.04422](https://arxiv.org/abs/2211.04422), 2022. (Mainly about the LRA preconditioner. See [these supplementary materials](https://drive.google.com/file/d/1CTNx1q67_py87jn-0OI-vSLcsM1K7VsM/view) for detailed math derivations.)
4) Xi-Lin Li. Stochastic Hessian fittings on Lie groups, [arXiv:2402.11858](https://arxiv.org/abs/2402.11858), 2024. (Some theoretical works on the efficiency of PSGD. The Hessian fitting problem is shown to be strongly convex on set ${\rm GL}(n, \mathbb{R})/R_{\rm polar}$.)
5) Omead Pooladzandi, Xi-Lin Li. Curvature-informed SGD via general purpose Lie-group preconditioners, [arXiv:2402.04553](https://arxiv.org/abs/2402.04553), 2024. (Plenty of benchmark results and analyses for PSGD vs. other optimizers.)


## License

[![CC BY 4.0][cc-by-image]][cc-by]

This work is licensed under a [Creative Commons Attribution 4.0 International License][cc-by].

2024 Evan Walters, Omead Pooladzandi, Xi-Lin Li


[cc-by]: http://creativecommons.org/licenses/by/4.0/
[cc-by-image]: https://licensebuttons.net/l/by/4.0/88x31.png
[cc-by-shield]: https://img.shields.io/badge/License-CC%20BY%204.0-lightgrey.svg


            
