jax-optix


- Name: jax-optix
- Version: 0.1.8
- Summary: Zero-overhead functional lensing for JAX PyTrees
- Upload time: 2025-02-13 16:06:02
- Requires Python: >=3.12
- License: MIT
- Keywords: equinox, functional, jax, lens, optics
- Requirements: none recorded

# Optix 🔍

A functional lensing library for JAX/Equinox that lets you focus on and modify nested values within PyTree structures. Code written with Optix lowers to the same HLO (the XLA intermediate representation) as direct access, so lenses add no runtime overhead.

## Features

- Type-safe lenses for any JAX PyTree structure
- Zero runtime overhead (generates identical HLO code)
- Intuitive API for accessing and modifying nested values
- Complete static typing support

## Example

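```python
import equinox as eqx
import jax
import jax.numpy as jnp
from optix import focus


# Example PyTree types. Defining them as Equinox modules is an assumption;
# the original snippet does not show how MyStruct and NestedStruct are declared.
class NestedStruct(eqx.Module):
    y: jax.Array


class MyStruct(eqx.Module):
    x: jax.Array
    nested: NestedStruct


# Create a nested PyTree structure
data = MyStruct(
    x=jnp.array([1.0, 2.0]),
    nested=NestedStruct(y=jnp.array(3.0)),
)

# Focus on and modify a nested value
result = focus(data).at(lambda x: x.nested.y).apply(jnp.square)
# result has the same structure, with y squared and x untouched:
# MyStruct(
#     x=Array([1., 2.], dtype=float32),
#     nested=NestedStruct(
#         y=Array(9., dtype=float32)
#     )
# )
```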
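
To illustrate the idea behind the example above, here is a minimal, standard-library-only sketch of what a lens does: record the path a selector lambda walks, then rebuild the structure with one value changed. This is not how Optix is implemented. The names `_PathRecorder`, `_update`, `sketch_apply`, `Outer`, and `Nested` are hypothetical, and plain frozen dataclasses stand in for JAX PyTrees.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable


class _PathRecorder:
    """Hypothetical helper: records attribute names accessed by a selector."""

    def __init__(self, path: tuple[str, ...] = ()):
        self._path = path

    def __getattr__(self, name: str) -> "_PathRecorder":
        # Only called for attributes that do not exist, i.e. the field names.
        return _PathRecorder(self._path + (name,))


def _update(obj: Any, path: tuple[str, ...], fn: Callable[[Any], Any]) -> Any:
    """Return a copy of obj with fn applied to the value found at path."""
    if not path:
        return fn(obj)
    head, *rest = path
    child = _update(getattr(obj, head), tuple(rest), fn)
    # dataclasses.replace builds a new instance; the original is not mutated.
    return replace(obj, **{head: child})


def sketch_apply(obj: Any, selector: Callable[[Any], Any], fn: Callable[[Any], Any]) -> Any:
    """Apply fn to the field chosen by selector, returning a new structure."""
    path = selector(_PathRecorder())._path
    return _update(obj, path, fn)


@dataclass(frozen=True)
class Nested:
    y: float


@dataclass(frozen=True)
class Outer:
    x: tuple[float, ...]
    nested: Nested


data = Outer(x=(1.0, 2.0), nested=Nested(y=3.0))
result = sketch_apply(data, lambda d: d.nested.y, lambda v: v ** 2)
print(result)  # Outer(x=(1.0, 2.0), nested=Nested(y=9.0))
```

The update is purely functional: `data` is left unchanged and untouched fields such as `x` are shared between the old and new structures.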

## Installation

```bash
pip install jax-optix
```

## License

MIT License

## Credits

Special thanks to [Patrick Kidger](https://kidger.site/) for his helpful hints and for the [Equinox](https://github.com/patrick-kidger/equinox) library, which this project builds upon.

            

## Raw data

{
    "_id": null,
    "home_page": null,
    "name": "jax-optix",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.12",
    "maintainer_email": null,
    "keywords": "equinox, functional, jax, lens, optics",
    "author": null,
    "author_email": "Jonas K\u00f6hler <jonas.koehler.ks@gmail.com>",
    "download_url": "https://files.pythonhosted.org/packages/49/6d/898a30442b8ab13d2a9ee6c7aed52abb51c86bdf55e7f16b4c42048665ca/jax_optix-0.1.8.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "MIT",
    "summary": "Zero-overhead functional lensing for JAX PyTrees",
    "version": "0.1.8",
    "project_urls": {
        "Homepage": "https://github.com/jonkhler/optix"
    },
    "split_keywords": [
        "equinox",
        " functional",
        " jax",
        " lens",
        " optics"
    ],
    "urls": [
        {
            "comment_text": null,
            "digests": {
                "blake2b_256": "ca6a9590cd31ef28ff7eae7228613e2489774e91dff889c559e97cf851e9bf58",
                "md5": "82271c6bcd8103ed03c2ec44ef98b246",
                "sha256": "3a2a191479a3114a1995252ffc97c4ac16d0b6e7cd0b6ee9457687015321b3a6"
            },
            "downloads": -1,
            "filename": "jax_optix-0.1.8-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "82271c6bcd8103ed03c2ec44ef98b246",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.12",
            "size": 3692,
            "upload_time": "2025-02-13T16:06:01",
            "upload_time_iso_8601": "2025-02-13T16:06:01.482588Z",
            "url": "https://files.pythonhosted.org/packages/ca/6a/9590cd31ef28ff7eae7228613e2489774e91dff889c559e97cf851e9bf58/jax_optix-0.1.8-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": null,
            "digests": {
                "blake2b_256": "496d898a30442b8ab13d2a9ee6c7aed52abb51c86bdf55e7f16b4c42048665ca",
                "md5": "af3076d42d906a6716166e4a24d886d8",
                "sha256": "48e55e0faa986b934365331c97b1039321940ae339e1dda9126f2c3f2fef2a7a"
            },
            "downloads": -1,
            "filename": "jax_optix-0.1.8.tar.gz",
            "has_sig": false,
            "md5_digest": "af3076d42d906a6716166e4a24d886d8",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.12",
            "size": 3509,
            "upload_time": "2025-02-13T16:06:02",
            "upload_time_iso_8601": "2025-02-13T16:06:02.758860Z",
            "url": "https://files.pythonhosted.org/packages/49/6d/898a30442b8ab13d2a9ee6c7aed52abb51c86bdf55e7f16b4c42048665ca/jax_optix-0.1.8.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2025-02-13 16:06:02",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "jonkhler",
    "github_project": "optix",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": false,
    "lcname": "jax-optix"
}
        