FixedPointJAX


Name: FixedPointJAX
Version: 0.0.25
Home page: https://github.com/esbenscriver/FixedPointJAX.git
Summary: Fixed-point iterations for root finding implemented in JAX
Upload time: 2024-09-06 13:09:49
Author: Esben Scriver Andersen
Requires Python: >=3.6

# Fixed-point solver
FixedPointJAX is a simple implementation of a fixed-point iteration algorithm for root finding in JAX.

* Strives to be minimal
* Has no dependencies other than JAX
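
The core idea of fixed-point iteration can be sketched in a few lines of JAX. The function below is a hypothetical illustration, not FixedPointJAX's own implementation: the name `fixed_point_sketch`, the max-abs stopping rule, and the defaults `tol=1e-6` and `max_iter=1000` are assumptions. Following the convention of the usage example below, `fun` returns a pair of the next iterate and a residual. Only the next iterate is used here, and iteration stops once successive iterates differ by less than `tol`.

```python
import jax
import jax.numpy as jnp


def fixed_point_sketch(fun, x0, tol=1e-6, max_iter=1000):
    """Iterate x <- fun(x)[0] until the step is smaller than `tol`.

    Hypothetical helper for illustration only (not part of FixedPointJAX).
    `fun` must return a pair (next_iterate, residual), as in the usage example.
    """

    def cond(state):
        _, step_norm, i = state
        # Keep iterating while the last step is large and the budget remains
        return (step_norm > tol) & (i < max_iter)

    def body(state):
        x, _, i = state
        x_next, _ = fun(x)  # the residual is ignored in this sketch
        step_norm = jnp.max(jnp.abs(x_next - x))
        return x_next, step_norm, i + 1

    # Explicit dtypes keep the loop carry types identical across iterations
    init = (x0, jnp.asarray(jnp.inf, dtype=x0.dtype), jnp.int32(0))
    x, step_norm, iterations = jax.lax.while_loop(cond, body, init)
    return x, step_norm, iterations
```

For example, `fixed_point_sketch(lambda x: (jnp.cos(x), jnp.cos(x) - x), jnp.zeros(3))` converges to about 0.739, the fixed point of cos. Using `jax.lax.while_loop` instead of a Python loop lets JAX trace the loop, which matters if the solver is called inside `jax.jit`.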

## Installation

```bash
pip install FixedPointJAX
```

## Usage

```python

import jax.numpy as jnp
from jax import random

from FixedPointJAX import FixedPointRoot

# Logit (softmax) probabilities, normalized along `axis`;
# subtracting the max before exponentiating avoids overflow
def my_logit(x, axis=0):
    numerator = jnp.exp(x - jnp.max(x, axis=axis, keepdims=True))
    denominator = jnp.sum(numerator, axis=axis, keepdims=True)
    return numerator / denominator

# Map for the fixed-point iteration: returns the updated iterate x + z
# and the residual z = log(s0 / s), which is zero when my_logit(x) == s0
def my_fxp(x, s0):
    s = my_logit(x)
    z = jnp.log(s0 / s)
    return x + z, z

print('-----------------------------------------')
# Dimensions of the system of fixed-point equations
shape = (3, 4)

# Simulate target probabilities
s0 = my_logit(random.uniform(key=random.PRNGKey(123), shape=shape))

# Set up the fixed-point equation as a function of x only
fun = lambda x: my_fxp(x, s0)

# Initial guess
x0 = jnp.zeros_like(s0)

# Solve the fixed-point equation
x, (step_norm, root_norm, iterations) = FixedPointRoot(fun, x0)
print('-----------------------------------------')
print(f'System of fixed-point equations is solved: {jnp.allclose(x, fun(x)[0])}.')
print(f'Probabilities are identical: {jnp.allclose(s0, my_logit(x))}.')
print('-----------------------------------------')
```
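
Why this particular map works: because `my_logit` normalizes along axis 0, log(my_logit(x)) = x - logsumexp(x, axis=0), so one update x + log(s0 / s) equals log(s0) plus a constant in each column. A single step therefore reproduces `s0` up to rounding, and the solution `x` is only determined up to a constant added to each column. The sketch below checks both facts numerically without calling FixedPointJAX; `jax.scipy.special.logsumexp` is used only for the check, and the column shifts `jnp.arange(4.0)` are arbitrary illustrative values.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.special import logsumexp


def my_logit(x, axis=0):
    # Same softmax as in the usage example
    numerator = jnp.exp(x - jnp.max(x, axis=axis, keepdims=True))
    return numerator / jnp.sum(numerator, axis=axis, keepdims=True)


s0 = my_logit(random.uniform(key=random.PRNGKey(123), shape=(3, 4)))
x0 = jnp.zeros_like(s0)

# One update of the usage example's map: x1 = x0 + log(s0 / my_logit(x0))
x1 = x0 + jnp.log(s0 / my_logit(x0))

# Closed form of the same update: log(s0) plus a per-column constant
x1_closed = jnp.log(s0) + logsumexp(x0, axis=0, keepdims=True)

one_step_ok = bool(jnp.allclose(x1, x1_closed, atol=1e-5))
matches_s0 = bool(jnp.allclose(my_logit(x1), s0, atol=1e-6))

# Adding a constant to each column (arbitrary values) leaves the probabilities unchanged
shift = jnp.arange(4.0)
shift_invariant = bool(jnp.allclose(my_logit(x1 + shift), my_logit(x1), atol=1e-6))
print(one_step_ok, matches_s0, shift_invariant)
```

This non-uniqueness is also why the usage example checks `my_logit(x)` against `s0` rather than comparing `x` to a particular array.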

            

        