# L-BFGS optimizer written with JAX
## Features
- Implements the Limited-memory [BFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS) algorithm.
- JIT/vmap/pmap compatible for performance with JAX.
- Note: requirements.txt is set up for JAX[CPU].
## Usage
Define a function to minimize:
```python
def func(x):
    # sum-of-squares distance to a target vector `coefficients`
    return jnp.sum((x - coefficients) ** 2)
```
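The snippet above assumes `coefficients` (and, later, an initial point `x0`) are defined elsewhere. A minimal sketch consistent with the output shown further down (these specific values are an assumption, not part of the package):
```python
import jax
import jax.numpy as jnp

coefficients = jnp.arange(100.0)   # assumed target vector; the minimizer is x == coefficients
x0 = jnp.zeros_like(coefficients)  # assumed starting point, passed to optimizer.init below
```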
Call `Lbfgs` with:
- `f`: function to minimize
- `m`: number of previous iterations to store in memory
- `tol`: convergence tolerance
```python
optimizer = Lbfgs(f=func, m=10, tol=1e-6)
```
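The history size `m` controls how many curvature pairs are kept for the inverse-Hessian approximation: a larger value can speed convergence on ill-conditioned problems at the cost of memory. For example (hypothetical settings, not taken from the package's examples):
```python
# Hypothetical: longer history and tighter tolerance for a harder problem
optimizer_long_history = Lbfgs(f=func, m=50, tol=1e-8)
```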
Iterate to find the minimum:
```python
# Initialize the optimizer state from the starting point
opt_state = optimizer.init(x0)

@jax.jit
def opt_step(carry, _):
    opt_state, losses = carry
    opt_state = optimizer.update(opt_state)
    # record the objective at the current iterate (k is the iteration counter in the state)
    losses = losses.at[opt_state.k].set(func(opt_state.position))
    return (opt_state, losses), _

iterations = 10000  # <-- a lot of iterations!
losses = jnp.zeros((iterations,))
(final_state, losses), _ = jax.lax.scan(opt_step, (opt_state, losses), None, length=iterations)

# losses has length `iterations`; entries never written (because the optimizer
# converged early) remain 0, so mask them out before inspecting or plotting
losses = jnp.where(losses == 0, jnp.nan, losses)
```
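After the scan, `final_state` is the last optimizer state; the fields used above (`position` and `k`) can be read off directly. This is roughly how the output below was produced (a sketch, not the exact script):
```python
print(final_state.position)                                       # the minimizer (printed below)
print("Function value at minimum:", func(final_state.position))
print("k:", final_state.k)                                        # iterations taken before the early stop
```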
Output:
```
[-7.577116e-15 1.000000e+00 2.000000e+00 3.000000e+00 4.000000e+00
5.000000e+00 6.000000e+00 7.000000e+00 8.000000e+00 9.000000e+00
1.000000e+01 1.100000e+01 1.200000e+01 1.300000e+01 1.400000e+01
1.500000e+01 1.600000e+01 1.700000e+01 1.800000e+01 1.900000e+01
2.000000e+01 2.100000e+01 2.200000e+01 2.300000e+01 2.400000e+01
2.500000e+01 2.600000e+01 2.700000e+01 2.800000e+01 2.900000e+01
3.000000e+01 3.100000e+01 3.200000e+01 3.300000e+01 3.400000e+01
3.500000e+01 3.600000e+01 3.700000e+01 3.800000e+01 3.900000e+01
4.000000e+01 4.100000e+01 4.200000e+01 4.300000e+01 4.400000e+01
4.500000e+01 4.600000e+01 4.700000e+01 4.800000e+01 4.900000e+01
5.000000e+01 5.100000e+01 5.200000e+01 5.300000e+01 5.400000e+01
5.500000e+01 5.600000e+01 5.700000e+01 5.800000e+01 5.900000e+01
6.000000e+01 6.100000e+01 6.200000e+01 6.300000e+01 6.400000e+01
6.500000e+01 6.600000e+01 6.700000e+01 6.800000e+01 6.900000e+01
7.000000e+01 7.100000e+01 7.200000e+01 7.300000e+01 7.400000e+01
7.500000e+01 7.600000e+01 7.700000e+01 7.800000e+01 7.900000e+01
8.000000e+01 8.100000e+01 8.200000e+01 8.300000e+01 8.400000e+01
8.500000e+01 8.600000e+01 8.700000e+01 8.800000e+01 8.900000e+01
9.000000e+01 9.100000e+01 9.200000e+01 9.300000e+01 9.400000e+01
9.500000e+01 9.600000e+01 9.700000e+01 9.800000e+01 9.900000e+01]
Function value at minimum: 5.7412694e-29
k: 2 #<-- stops early if gradient norm is less than tol!!
```
*NOTE*: The examples include the quadratic function and the Rosenbrock function.\
The 1000-dimensional Rosenbrock function was solved in 4038 steps.
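For reference, a standard N-dimensional Rosenbrock function written with jax.numpy (a sketch; the packaged example may differ in details):
```python
import jax.numpy as jnp

def rosenbrock(x):
    # standard N-dimensional Rosenbrock; global minimum of 0 at x = (1, 1, ..., 1)
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)
```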
## Installation
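The package is published on PyPI under the name `GradientTransformation`, so a standard pip install should work:
```
pip install GradientTransformation
```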