agjax


Name: agjax
Version: 0.3.0
Summary: A jax wrapper for autograd-differentiable functions.
Upload time: 2024-02-21 04:04:19
Requires Python: >=3.7
License: The MIT License (MIT) Copyright (c) 2023 Martin Schubert Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
Keywords: autograd jax python wrapper gradient
# Agjax -- jax wrapper for autograd-differentiable functions.
`v0.3.0`

Agjax allows existing code built with autograd to be used with the jax framework.

In particular, `agjax.wrap_for_jax` allows arbitrary autograd functions to be differentiated using `jax.grad`. Several other function transformations (e.g. compilation via `jax.jit`) are not supported.

Meanwhile, `agjax.experimental.wrap_for_jax` supports `grad`, `jit`, `vmap`, and `jacrev`. However, it depends on certain under-the-hood behavior of jax, which is not guaranteed to remain unchanged. It is also more restrictive in the signatures of the functions it can wrap: all arguments and outputs must be convertible to valid jax types. (`agjax.wrap_for_jax` also supports non-jax inputs and outputs, e.g. strings.)
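One general way a wrapper can make code that jax cannot trace compatible with both `grad` and `jit` is to declare the output shapes and dtypes up front, call the function through `jax.pure_callback`, and attach its gradient with `jax.custom_vjp`. The sketch below illustrates that pattern under that assumption; it is not agjax's implementation. `_f_host` and `_f_vjp_host` are hypothetical stand-ins for an autograd function and the VJP autograd would compute for it. Batching with `vmap` is left out, since the way `pure_callback` is batched has changed across jax versions.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical host-side function standing in for an autograd function:
# f(x, y) = x * cos(y), evaluated with plain NumPy outside of jax tracing.
def _f_host(x, y):
    return (x * np.cos(y)).astype(x.dtype)

# Hypothetical hand-written VJP standing in for one computed by autograd:
# df/dx = cos(y) and df/dy = -x * sin(y), each scaled by the cotangent g.
def _f_vjp_host(x, y, g):
    return (
        (g * np.cos(y)).astype(x.dtype),
        (-g * x * np.sin(y)).astype(y.dtype),
    )

@jax.custom_vjp
def f(x, y):
    # jax cannot trace into the host function, so the output shape and
    # dtype must be declared explicitly.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_f_host, out_spec, x, y)

def f_fwd(x, y):
    # Save the primal inputs as residuals for the backward pass.
    return f(x, y), (x, y)

def f_bwd(residuals, g):
    # pure_callback is not differentiable by itself, so the gradient is
    # supplied here, again via a host callback with declared output specs.
    x, y = residuals
    grad_specs = (
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        jax.ShapeDtypeStruct(y.shape, y.dtype),
    )
    return jax.pure_callback(_f_vjp_host, grad_specs, x, y, g)

f.defvjp(f_fwd, f_bwd)

x = jnp.arange(5, dtype=jnp.float32)
y = jnp.arange(5, 10, dtype=jnp.float32)

# grad and jit compose here because jax only needs the declared specs.
loss = lambda a, b: jnp.sum(f(a, b))
grad_x, grad_y = jax.jit(jax.grad(loss, argnums=(0, 1)))(x, y)
```

Because the backward pass is itself a callback, this pattern keeps the host function opaque to jax while still letting `jit` compile the surrounding computation.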

## Installation
```
pip install agjax
```

## Usage
Basic usage is as follows:
```python
import agjax
import autograd.numpy as npa
import jax

@agjax.wrap_for_jax
def fn(x, y):
  return x * npa.cos(y)

jax.grad(fn, argnums=(0, 1))(1.0, 0.0)

# (Array(1., dtype=float32), Array(0., dtype=float32))
```
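The returned gradient follows directly from the partial derivatives of `fn`, evaluated at `(x, y) = (1.0, 0.0)`:

```latex
f(x, y) = x \cos y, \qquad
\frac{\partial f}{\partial x} = \cos y, \qquad
\frac{\partial f}{\partial y} = -x \sin y,
\qquad
\left.\nabla f\right|_{(1,\,0)} = (\cos 0,\; -1 \cdot \sin 0) = (1,\; 0).
```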

The experimental wrapper is similar, but requires that the shapes and dtypes of the function outputs be specified, as with `jax.pure_callback`.
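```python
import agjax
import autograd.numpy as npa
import jax
import jax.numpy as jnp

wrapped_fn = agjax.experimental.wrap_for_jax(
  lambda x, y: x * npa.cos(y),
  result_shape_dtypes=jnp.ones((5,)),
)

# Jacobian of the output with respect to the first argument, `x`.
jax.jacrev(wrapped_fn, argnums=0)(
  jnp.arange(5, dtype=float), jnp.arange(5, 10, dtype=float)
)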

# [[ 0.28366217  0.          0.          0.          0.        ]
#  [ 0.          0.96017027  0.          0.          0.        ]
#  [ 0.          0.          0.75390226  0.          0.        ]
#  [ 0.          0.          0.         -0.14550003  0.        ]
#  [ 0.          0.          0.          0.         -0.91113025]]
```
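The printed matrix can be checked by hand. The wrapped function acts elementwise, `f_i(x, y) = x_i * cos(y_i)`, so output `i` depends only on `x_i`: the Jacobian with respect to `x` is diagonal, with `cos(5), ..., cos(9)` on the diagonal. The following plain-Python check of that reasoning needs neither jax nor autograd; `expected_jacobian_wrt_x` is a hypothetical helper written for illustration.

```python
import math

def expected_jacobian_wrt_x(y):
    """Jacobian of f(x, y) = x * cos(y) (elementwise) with respect to x.

    Entry (i, j) is cos(y[i]) when i == j and 0 otherwise, because
    output i depends only on x[i].
    """
    n = len(y)
    return [
        [math.cos(y[i]) if i == j else 0.0 for j in range(n)]
        for i in range(n)
    ]

jac = expected_jacobian_wrt_x([5.0, 6.0, 7.0, 8.0, 9.0])
for row in jac:
    print("  ".join(f"{v: .8f}" for v in row))
```

The jax output differs from these values only in the last digits, since jax defaults to 32-bit floats while Python's `math` module uses 64-bit floats.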

Agjax wrappers are intended to be quite general. They support functions with multiple inputs and outputs, as well as functions that have nondifferentiable outputs or arguments with respect to which they cannot be differentiated. These should be specified using the `nondiff_argnums` and `nondiff_outputnums` arguments. In the experimental wrapper, these must still be jax-convertible types, while in the standard wrapper they may have arbitrary types.

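```python
import functools

import agjax
import autograd.numpy as npa
import jax

# Argument 2 (the string) is not differentiated; output 1 is returned as aux.
@functools.partial(
  agjax.wrap_for_jax, nondiff_argnums=(2,), nondiff_outputnums=(1,)
)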
def fn(x, y, string_arg):
  return x * npa.cos(y), string_arg * 2

(value, aux), grad = jax.value_and_grad(
  fn, argnums=(0, 1), has_aux=True
)(1.0, 0.0, "test")

print(f"value = {value}")
print(f"  aux = {aux}")
print(f" grad = {grad}")
```
```
value = 1.0
  aux = testtest
 grad = (Array(1., dtype=float32), Array(0., dtype=float32))
```

            

Raw data

            {
    "_id": null,
    "home_page": "",
    "name": "agjax",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.7",
    "maintainer_email": "\"Martin F. Schubert\" <mfschubert@gmail.com>",
    "keywords": "autograd,jax,python,wrapper,gradient",
    "author": "",
    "author_email": "\"Martin F. Schubert\" <mfschubert@gmail.com>",
    "download_url": "https://files.pythonhosted.org/packages/02/d2/fa61c29b3834400d4e235fe6c50908d74018873919df385a3ee9327eb37f/agjax-0.3.0.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "The MIT License (MIT)  Copyright (c) 2023 Martin Schubert  Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the \"Software\"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:  The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.  THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ",
    "summary": "A jax wrapper for autograd-differentiable functions.",
    "version": "0.3.0",
    "project_urls": null,
    "split_keywords": [
        "autograd",
        "jax",
        "python",
        "wrapper",
        "gradient"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "a622bead09822560b21c937b226cab3e9eb42de8733c10b0a9d47810e9d46b87",
                "md5": "18db07654114c53eb2caf2b3b8a9d268",
                "sha256": "9f365b48b0fb1442b1bf70077303ec5cf88e083d22dfe5bb2139894c3e72a57b"
            },
            "downloads": -1,
            "filename": "agjax-0.3.0-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "18db07654114c53eb2caf2b3b8a9d268",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.7",
            "size": 9671,
            "upload_time": "2024-02-21T04:04:17",
            "upload_time_iso_8601": "2024-02-21T04:04:17.288929Z",
            "url": "https://files.pythonhosted.org/packages/a6/22/bead09822560b21c937b226cab3e9eb42de8733c10b0a9d47810e9d46b87/agjax-0.3.0-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "02d2fa61c29b3834400d4e235fe6c50908d74018873919df385a3ee9327eb37f",
                "md5": "2a377cff3c781d6cbed428e55a93f565",
                "sha256": "8d0cae3c9be22c9c547b809b33b1e380e865726cc2ab79ea21f95c0b22963bbb"
            },
            "downloads": -1,
            "filename": "agjax-0.3.0.tar.gz",
            "has_sig": false,
            "md5_digest": "2a377cff3c781d6cbed428e55a93f565",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.7",
            "size": 11198,
            "upload_time": "2024-02-21T04:04:19",
            "upload_time_iso_8601": "2024-02-21T04:04:19.093388Z",
            "url": "https://files.pythonhosted.org/packages/02/d2/fa61c29b3834400d4e235fe6c50908d74018873919df385a3ee9327eb37f/agjax-0.3.0.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2024-02-21 04:04:19",
    "github": false,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "lcname": "agjax"
}
        