GradientTransformation

Name: GradientTransformation
Version: 1.0.0
Summary: A JAX-based L-BFGS optimizer
Author: Joseph Schafer <joeschafer28@gmail.com>
License: MIT
Requires Python: >=3.10
Keywords: jax, optimizer, lbfgs, optimization
Repository: https://github.com/JWSch4fer/LBFGS_JAX
Uploaded: 2024-10-09 17:42:56
Requirements: none recorded
# L-BFGS optimizer written with JAX

## Features

- Implements the [Limited-memory BFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS) (L-BFGS) algorithm (see the two-loop recursion sketch below).
- Compatible with JAX's `jit`/`vmap`/`pmap` transformations for performance.
- Note: `requirements.txt` is set up for `JAX[CPU]` (CPU-only JAX).
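For background, L-BFGS rebuilds the search direction from only the last `m` curvature pairs using the classic two-loop recursion. The sketch below illustrates that recursion with plain `jax.numpy`; it is independent of this package's internal implementation and the names are illustrative:

```python
import jax.numpy as jnp

def two_loop_direction(grad, s_history, y_history):
    """Search direction -H^{-1} grad from the last m curvature pairs.

    s_history, y_history: non-empty lists of arrays, ordered oldest -> newest,
    with s_i = x_{i+1} - x_i and y_i = grad_{i+1} - grad_i.
    """
    q = grad
    alphas = []
    # First loop: newest pair to oldest
    for s, y in zip(reversed(s_history), reversed(y_history)):
        rho = 1.0 / jnp.dot(y, s)
        alpha = rho * jnp.dot(s, q)
        q = q - alpha * y
        alphas.append((rho, alpha))
    # Scale by an initial inverse-Hessian estimate gamma * I
    s_last, y_last = s_history[-1], y_history[-1]
    gamma = jnp.dot(s_last, y_last) / jnp.dot(y_last, y_last)
    r = gamma * q
    # Second loop: oldest pair to newest
    for (s, y), (rho, alpha) in zip(zip(s_history, y_history), reversed(alphas)):
        beta = rho * jnp.dot(y, r)
        r = r + s * (alpha - beta)
    return -r   # descent direction
```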

## Usage
Define a function to minimize (here a simple quadratic):
```python
import jax.numpy as jnp

def func(x):
    return jnp.sum((-1 * coefficients + x) ** 2)
```
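The snippet above assumes `coefficients` and a starting point `x0` already exist; for the quadratic example whose output is shown below they could be defined like this (illustrative values, not part of the package):

```python
coefficients = jnp.arange(100.0)   # target vector; func is minimized at x == coefficients
x0 = jnp.zeros_like(coefficients)  # initial guess passed to the optimizer
```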

Call `Lbfgs` with:
- `f`: function to minimize
- `m`: number of previous iterations to store in memory
- `tol`: convergence tolerance on the gradient norm
```python
optimizer = Lbfgs(f=func, m=10, tol=1e-6)
```

Iterate to find the minimum:
```python
import jax

# Initialize the optimizer state from the starting point
opt_state = optimizer.init(x0)

@jax.jit
def opt_step(carry, _):
    opt_state, losses = carry
    opt_state = optimizer.update(opt_state)
    losses = losses.at[opt_state.k].set(func(opt_state.position))
    return (opt_state, losses), _

iterations = 10000   # <-- a lot of iterations!
losses = jnp.zeros((iterations,))
(final_state, losses), _ = jax.lax.scan(opt_step, (opt_state, losses), None, length=iterations)
# losses has length `iterations`; slots never written remain 0, so mask them as NaN
losses = jnp.where(losses == 0, jnp.nan, losses)
```
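After the scan, the solution and the iteration at which the gradient-norm test triggered can be read from the final state (using the same fields, `position` and `k`, as above):

```python
print(final_state.position)                                      # position at the minimum
print("Function value at minimum:", func(final_state.position))
print("k: ", final_state.k)                                      # iterations actually performed
```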

Output:
```
[-7.577116e-15  1.000000e+00  2.000000e+00  3.000000e+00  4.000000e+00
  5.000000e+00  6.000000e+00  7.000000e+00  8.000000e+00  9.000000e+00
  1.000000e+01  1.100000e+01  1.200000e+01  1.300000e+01  1.400000e+01
  1.500000e+01  1.600000e+01  1.700000e+01  1.800000e+01  1.900000e+01
  2.000000e+01  2.100000e+01  2.200000e+01  2.300000e+01  2.400000e+01
  2.500000e+01  2.600000e+01  2.700000e+01  2.800000e+01  2.900000e+01
  3.000000e+01  3.100000e+01  3.200000e+01  3.300000e+01  3.400000e+01
  3.500000e+01  3.600000e+01  3.700000e+01  3.800000e+01  3.900000e+01
  4.000000e+01  4.100000e+01  4.200000e+01  4.300000e+01  4.400000e+01
  4.500000e+01  4.600000e+01  4.700000e+01  4.800000e+01  4.900000e+01
  5.000000e+01  5.100000e+01  5.200000e+01  5.300000e+01  5.400000e+01
  5.500000e+01  5.600000e+01  5.700000e+01  5.800000e+01  5.900000e+01
  6.000000e+01  6.100000e+01  6.200000e+01  6.300000e+01  6.400000e+01
  6.500000e+01  6.600000e+01  6.700000e+01  6.800000e+01  6.900000e+01
  7.000000e+01  7.100000e+01  7.200000e+01  7.300000e+01  7.400000e+01
  7.500000e+01  7.600000e+01  7.700000e+01  7.800000e+01  7.900000e+01
  8.000000e+01  8.100000e+01  8.200000e+01  8.300000e+01  8.400000e+01
  8.500000e+01  8.600000e+01  8.700000e+01  8.800000e+01  8.900000e+01
  9.000000e+01  9.100000e+01  9.200000e+01  9.300000e+01  9.400000e+01
  9.500000e+01  9.600000e+01  9.700000e+01  9.800000e+01  9.900000e+01]

Function value at minimum: 5.7412694e-29
k:  2   # <-- stops early if the gradient norm is less than tol!
```

*NOTE*: The examples include the quadratic function and the Rosenbrock function (a sketch of the Rosenbrock setup follows below).\
A 1000-dimensional Rosenbrock problem was solved in 4038 steps.
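A Rosenbrock setup might look like the following; this is a sketch using the documented constructor, not the packaged example itself:

```python
import jax.numpy as jnp

def rosenbrock(x):
    # Standard N-dimensional Rosenbrock function; global minimum at x = (1, ..., 1)
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)

optimizer = Lbfgs(f=rosenbrock, m=10, tol=1e-6)
opt_state = optimizer.init(jnp.zeros(1000))   # 1000-dimensional starting point
```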


## Installation
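The package is published on PyPI as `GradientTransformation` (Python >= 3.10), so it should be installable with pip:

```
pip install GradientTransformation
```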

