agjax

- Name: agjax
- Version: 0.3.7
- Summary: A jax wrapper for autograd-differentiable functions.
- Upload time: 2025-10-20 18:44:55
- Requires Python: >=3.7
- License: None
- Keywords: autograd, jax, python, wrapper, gradient
- Requirements: No requirements were recorded.

# Agjax - jax wrapper for autograd-differentiable functions.
[![Docs](https://img.shields.io/badge/Docs-blue.svg)](https://invrs-io.github.io/agjax/)
[![Continuous integration](https://github.com/invrs-io/agjax/actions/workflows/build-ci.yml/badge.svg)](https://github.com/invrs-io/agjax/actions)
[![PyPI version](https://img.shields.io/pypi/v/agjax)](https://pypi.org/project/agjax/)

Agjax allows existing code built with autograd to be used with the jax framework.

In particular, `agjax.wrap_for_jax` allows arbitrary autograd functions to be differentiated using `jax.grad`. Several other function transformations (e.g. compilation via `jax.jit`) are not supported.

Meanwhile, `agjax.experimental.wrap_for_jax` supports `grad`, `jit`, `vmap`, and `jacrev`. However, it depends on certain under-the-hood behavior of jax, which is not guaranteed to remain unchanged. It is also more restrictive about the signatures of the functions it can wrap: all arguments and outputs must be convertible to valid jax types. (In contrast, `agjax.wrap_for_jax` also supports non-jax inputs and outputs, e.g. strings.)

## Installation
```
pip install agjax
```

## Usage
Basic usage is as follows:
```python
import agjax
import autograd.numpy as npa
import jax

@agjax.wrap_for_jax
def fn(x, y):
  return x * npa.cos(y)

jax.grad(fn, argnums=(0, 1))(1.0, 0.0)

# (Array(1., dtype=float32), Array(0., dtype=float32))
```
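
To give some intuition for why a wrapper of this kind can work with `jax.grad` but not with `jax.jit`, here is a rough sketch of the general technique using `jax.custom_vjp`. It is not agjax's actual implementation: the names `external_fn`, `_numpy_impl`, `external_fn_fwd`, and `external_fn_bwd` are hypothetical, and a hand-written derivative stands in for the vector-Jacobian product that autograd would supply.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _numpy_impl(x, y):
    # Stand-in for an autograd function: plain NumPy, opaque to jax tracing.
    return np.asarray(x) * np.cos(np.asarray(y))


@jax.custom_vjp
def external_fn(x, y):
    # Evaluate eagerly and convert the result back to a jax array.
    return jnp.asarray(_numpy_impl(x, y))


def external_fn_fwd(x, y):
    # Forward rule: compute the value and keep the inputs as residuals.
    return jnp.asarray(_numpy_impl(x, y)), (x, y)


def external_fn_bwd(residuals, g):
    # Backward rule: hand-written vector-Jacobian product (assumption: a real
    # wrapper would obtain this from autograd rather than writing it out).
    x, y = (np.asarray(r) for r in residuals)
    g = np.asarray(g)
    grad_x = g * np.cos(y)
    grad_y = -g * x * np.sin(y)
    return jnp.asarray(grad_x), jnp.asarray(grad_y)


external_fn.defvjp(external_fn_fwd, external_fn_bwd)

grad_x, grad_y = jax.grad(external_fn, argnums=(0, 1))(
    jnp.float32(1.0), jnp.float32(0.0)
)
```

Because the forward and backward rules call `np.asarray` on their inputs, they need concrete values. Under `jax.jit` the inputs would be tracers, and that conversion raises an error, which is consistent with the limitation described above.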

The experimental wrapper is similar, but requires that the shapes and dtypes of the function outputs be specified, similar to `jax.pure_callback`.
```python
import agjax
import autograd.numpy as npa
import jax
import jax.numpy as jnp

wrapped_fn = agjax.experimental.wrap_for_jax(
  lambda x, y: x * npa.cos(y),
  result_shape_dtypes=jnp.ones((5,)),
)

jax.jacrev(wrapped_fn, argnums=0)(
  jnp.arange(5, dtype=float), jnp.arange(5, 10, dtype=float)
)

# [[ 0.28366217  0.          0.          0.          0.        ]
#  [ 0.          0.96017027  0.          0.          0.        ]
#  [ 0.          0.          0.75390226  0.          0.        ]
#  [ 0.          0.          0.         -0.14550003  0.        ]
#  [ 0.          0.          0.          0.         -0.91113025]]
```
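
For comparison, the sketch below shows how `jax.pure_callback` itself takes a result specification. It uses only jax and NumPy, not agjax; `host_fn` and `call_host_fn` are hypothetical names chosen for illustration. The specification is needed because jax cannot trace into the host function to infer the output shape and dtype.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_fn(x, y):
    # Runs on the host with concrete values; plain NumPy, not traceable by jax.
    return np.asarray(x) * np.cos(np.asarray(y))


@jax.jit
def call_host_fn(x, y):
    # Declare the output shape and dtype up front, here matching the input x.
    result_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_fn, result_spec, x, y)


x = jnp.arange(5, dtype=jnp.float32)
y = jnp.arange(5, 10, dtype=jnp.float32)
out = call_host_fn(x, y)
```

An array such as `jnp.ones((5,))`, as passed to `result_shape_dtypes` in the agjax example, carries the same shape and dtype information as a `jax.ShapeDtypeStruct`.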

Agjax wrappers are intended to be quite general: they support functions with multiple inputs and outputs, including arguments that should not be differentiated and outputs that are not differentiable. Specify these with the `nondiff_argnums` and `nondiff_outputnums` arguments. In the experimental wrapper, such values must still be jax-convertible types, while in the standard wrapper they may have arbitrary types.

```python
import functools

import agjax
import autograd.numpy as npa
import jax

# Argument 2 (a string) and output 1 are excluded from differentiation.
@functools.partial(
  agjax.wrap_for_jax, nondiff_argnums=(2,), nondiff_outputnums=(1,)
)
def fn(x, y, string_arg):
  return x * npa.cos(y), string_arg * 2

(value, aux), grad = jax.value_and_grad(
  fn, argnums=(0, 1), has_aux=True
)(1.0, 0.0, "test")

print(f"value = {value}")
print(f"  aux = {aux}")
print(f" grad = {grad}")
```
```
value = 1.0
  aux = testtest
 grad = (Array(1., dtype=float32), Array(0., dtype=float32))
```

            

Raw data

{
    "_id": null,
    "home_page": null,
    "name": "agjax",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.7",
    "maintainer_email": "\"Martin F. Schubert\" <mfschubert@gmail.com>",
    "keywords": "autograd, jax, python, wrapper, gradient",
    "author": null,
    "author_email": "\"Martin F. Schubert\" <mfschubert@gmail.com>",
    "download_url": "https://files.pythonhosted.org/packages/f6/c1/d18527d79dc55332d0b50c37942453a5ee1ac29bc1256da00aa77a94ce82/agjax-0.3.7.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": null,
    "summary": "A jax wrapper for autograd-differentiable functions.",
    "version": "0.3.7",
    "project_urls": null,
    "split_keywords": [
        "autograd",
        " jax",
        " python",
        " wrapper",
        " gradient"
    ],
    "urls": [
        {
            "comment_text": null,
            "digests": {
                "blake2b_256": "d3657be6f353c714271068f896e68e23940d163eb3f20b423a25fb22075447d8",
                "md5": "39cd3559f5df048981097cd5b6687fe3",
                "sha256": "eb6eb0c5222ec3f2210df60f7b806c5222f8ad3f1d4a98c9ccbdfdf3b801f57c"
            },
            "downloads": -1,
            "filename": "agjax-0.3.7-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "39cd3559f5df048981097cd5b6687fe3",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.7",
            "size": 9755,
            "upload_time": "2025-10-20T18:44:55",
            "upload_time_iso_8601": "2025-10-20T18:44:55.106416Z",
            "url": "https://files.pythonhosted.org/packages/d3/65/7be6f353c714271068f896e68e23940d163eb3f20b423a25fb22075447d8/agjax-0.3.7-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": null,
            "digests": {
                "blake2b_256": "f6c1d18527d79dc55332d0b50c37942453a5ee1ac29bc1256da00aa77a94ce82",
                "md5": "96256ec3a5ab93df6bfde96dfed75c96",
                "sha256": "89f0060e964896319863bb67deaa34423a2602536b209f110685220c122ffb53"
            },
            "downloads": -1,
            "filename": "agjax-0.3.7.tar.gz",
            "has_sig": false,
            "md5_digest": "96256ec3a5ab93df6bfde96dfed75c96",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.7",
            "size": 11308,
            "upload_time": "2025-10-20T18:44:55",
            "upload_time_iso_8601": "2025-10-20T18:44:55.868158Z",
            "url": "https://files.pythonhosted.org/packages/f6/c1/d18527d79dc55332d0b50c37942453a5ee1ac29bc1256da00aa77a94ce82/agjax-0.3.7.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2025-10-20 18:44:55",
    "github": false,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "lcname": "agjax"
}
        