xarray-jax


- Version: 0.0.5
- Author: Allen Wang
- Uploaded: 2024-09-30 21:17:11
- Requires Python: <4.0,>=3.10

# Simple Xarray + JAX Integration

This is an experiment in integrating Xarray and JAX in a simple way, leveraging [equinox](https://github.com/patrick-kidger/equinox).

``` python
import equinox as eqx
import jax
import jax.numpy as jnp
import xarray as xr
import xarray_jax as xj

# Construct a DataArray.
da = xr.DataArray(
    xr.Variable(["x", "y"], jnp.ones((2, 3))),
    coords={"x": [1, 2], "y": [3, 4, 5]},
    name="foo",
    attrs={"attr1": "value1"},
)

# Do some operations inside a JIT compiled function.
@eqx.filter_jit
def some_function(data):
    neg_data = -1.0 * data
    return neg_data * neg_data.coords["y"] # Multiply data by coords.

da = some_function(da)

# Construct an xr.DataArray with dummy data (useful for tree manipulation).
da_mask = jax.tree.map(lambda _: True, da)

# Use jax.grad.
@eqx.filter_jit
def fn(data):
    return (data**2.0).sum().data

grad = jax.grad(fn)(da)

# Convert to a custom XjDataArray, implemented as an equinox module.
# (Useful for avoiding potentially weird xarray interactions with JAX).
xj_da = xj.from_xarray(da)

# Convert back to an xr.DataArray.
da = xj.to_xarray(xj_da)

```

## Installation
```bash
pip install xarray_jax
```

## Status
- [x] PyTree node registrations
  - [x] `xr.Variable`
  - [x] `xr.DataArray`
  - [x] `xr.Dataset`
- [x] Minimal shadow types, implemented as [equinox modules](https://github.com/patrick-kidger/equinox), to handle edge cases. (Note: these types are plain data containers that hold the contents of the corresponding xarray objects; they do not implement any of the xarray methods.)
  - [x] `XjVariable`
  - [x] `XjDataArray`
  - [x] `XjDataset`
- [x] `xj.from_xarray` and `xj.to_xarray` functions to convert between `xj` and `xr` types.
- [x] Support for `xr` types with dummy data (useful for tree manipulation).
- [ ] Support for transformations that change the dimensionality of the data.

## Sharp Edges

### Prefer `eqx.filter_jit` over `jax.jit`
There are some edge cases with metadata that `eqx.filter_jit` handles but `jax.jit` does not.

### Operations that Increase the Dimensionality of the Data
Operations that increase the dimensionality of the data (e.g. `jnp.expand_dims`) will cause problems downstream.

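``` python
import jax
import jax.numpy as jnp
import xarray as xr
import xarray_jax  # noqa: F401  (importing registers the xr types as PyTree nodes)

var = xr.Variable(dims=("x", "y"), data=jnp.ones((3, 3)))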

# This will not error.
var = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), var)

# The error from expanding the dimensionality will be triggered here.
var = var + 1 
```

### Dispatching to jnp is not supported yet
Pending resolution of https://github.com/pydata/xarray/issues/7848.
``` python
import jax.numpy as jnp
import xarray as xr

var = xr.Variable(dims=("x", "y"), data=jnp.ones((3, 3)))

# This will fail.
jnp.square(var)

# This will work.
xr.apply_ufunc(jnp.square, var)
```
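To avoid writing `xr.apply_ufunc` at every call site, you can wrap elementwise `jnp` functions once. The sketch below uses a hypothetical helper, `lift`, that is not part of `xarray_jax` or xarray. The helper relies only on `xr.apply_ufunc`, which the example above shows working with JAX-backed variables. `apply_ufunc` passes the underlying array to the function and rewraps the result with the original dims. For a `DataArray`, it also keeps the coordinates.

```python
import jax.numpy as jnp
import xarray as xr


def lift(jnp_fn):
    """Hypothetical helper: wrap an elementwise jnp function so it accepts xarray objects."""

    def wrapped(obj):
        # apply_ufunc calls `jnp_fn` on the underlying JAX array and rewraps the result.
        return xr.apply_ufunc(jnp_fn, obj)

    return wrapped


square = lift(jnp.square)

var = xr.Variable(dims=("x", "y"), data=jnp.full((3, 3), 2.0))
squared = square(var)  # instead of jnp.square(var)

# The same wrapper works for a DataArray and keeps its coordinates.
da = xr.DataArray(var, coords={"x": [1, 2, 3], "y": [4, 5, 6]})
rooted = lift(jnp.sqrt)(da)
```

This wrapper only covers elementwise functions. Functions that add, remove, or reorder dimensions need `input_core_dims`/`output_core_dims` passed to `apply_ufunc`.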


## Distinction from the GraphCast Implementation
This experiment is largely inspired by the [GraphCast implementation](https://github.com/google-deepmind/graphcast/blob/main/graphcast/xarray_jax.py), and it directly reuses the `_HashableCoords` class from that project.

However, this experiment aims to:
1. Take a more minimalist approach (and thus omit some features, such as support for JAX arrays as coordinates).
2. Find a solution more compatible with common JAX PyTree manipulation patterns that trigger errors with Xarray types. For example, it's common to use boolean masks to filter out elements of a PyTree, but this tends to fail with Xarray types.

## Acknowledgements
This repo was made possible by great discussions within the JAX + Xarray open source community, especially [this one](https://github.com/pydata/xarray/discussions/8164). In particular, the author would like to acknowledge @shoyer, @mjwillson, and @TomNicholas.

            
