inox

Name: inox
Version: 0.6.3
Summary: Stainless neural networks in JAX
Upload time: 2024-09-02 10:25:38
Requires Python: >=3.9
Keywords: jax, pytree, neural networks, deep learning
![Inox's banner](https://raw.githubusercontent.com/francois-rozet/inox/master/docs/images/banner.svg)

# Stainless neural networks in JAX

Inox is a minimal [JAX](https://github.com/google/jax) library for neural networks with an intuitive [PyTorch](https://github.com/pytorch/pytorch)-like syntax. As with [Equinox](https://github.com/patrick-kidger/equinox), modules are represented as PyTrees, which enables complex architectures, easy manipulations, and functional transformations.

Inox aims to be a leaner version of Equinox by only retaining its core features: PyTrees and lifted transformations. In addition, Inox takes inspiration from other projects like [NNX](https://github.com/cgarciae/nnx) and [Serket](https://github.com/ASEM000/serket) to provide a versatile interface. Despite the differences, Inox remains compatible with the Equinox ecosystem, and its components (modules, transformations, ...) are for the most part interchangeable with those of Equinox.

> Inox means "stainless steel" in French 🔪

## Installation

The `inox` package is available on [PyPI](https://pypi.org/project/inox) and can be installed with `pip`.

```
pip install inox
```

Alternatively, if you need the latest features, you can install it from the repository.

```
pip install git+https://github.com/francois-rozet/inox
```

## Getting started

Modules are defined with an intuitive PyTorch-like syntax,

```python
import jax
import inox.nn as nn

init_key, data_key = jax.random.split(jax.random.key(0))

class MLP(nn.Module):
    def __init__(self, key):
        keys = jax.random.split(key, 3)

        self.l1 = nn.Linear(3, 64, key=keys[0])
        self.l2 = nn.Linear(64, 64, key=keys[1])
        self.l3 = nn.Linear(64, 3, key=keys[2])
        self.relu = nn.ReLU()

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))

        return x

model = MLP(init_key)
```

and are compatible with JAX transformations.

```python
X = jax.random.normal(data_key, (1024, 3))
Y = jax.numpy.sort(X, axis=-1)

@jax.jit
def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = jax.grad(loss_fn)(model, X, Y)
```

However, if a module's tree contains non-array leaves, such as strings or boolean flags, it becomes incompatible with JAX transformations. For this reason, Inox provides lifted transformations that treat all non-array leaves as static.

```python
import inox

model.name = 'stainless'  # not an array

@inox.jit
def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = inox.grad(loss_fn)(model, X, Y)
```

Inox also provides a partition mechanism to split the static definition of a module (structure, strings, flags, ...) from its dynamic content (parameters, indices, statistics, ...), which is convenient for updating parameters.

```python
model.mask = jax.numpy.array([1, 0, 1])  # not a parameter

static, params, others = model.partition(nn.Parameter)

@jax.jit
def loss_fn(params, others, x, y):
    model = static(params, others)
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = jax.grad(loss_fn)(params, others, X, Y)
params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

model = static(params, others)
```
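The partition idea can be sketched in plain Python, independent of JAX and Inox. The `partition` and `combine` helpers below are hypothetical and only illustrate the mechanism (splitting numeric leaves from a static skeleton), not Inox's actual API:

```python
def partition(tree):
    """Split a nested dict into a skeleton with placeholders and a flat
    dict of float leaves, keyed by their path."""
    leaves = {}

    def walk(node, path):
        if isinstance(node, dict):
            return {k: walk(v, path + (k,)) for k, v in node.items()}
        if isinstance(node, float):
            leaves[path] = node
            return None  # placeholder for a dynamic leaf
        return node  # static leaf (string, flag, ...) stays in the skeleton

    return walk(tree, ()), leaves

def combine(skeleton, leaves):
    """Reinsert the dynamic leaves into the skeleton."""

    def walk(node, path):
        if isinstance(node, dict):
            return {k: walk(v, path + (k,)) for k, v in node.items()}
        return leaves.get(path, node)

    return walk(skeleton, ())

module = {'l1': {'weight': 1.0, 'bias': 0.5}, 'name': 'stainless'}
skeleton, params = partition(module)

# Update the dynamic leaves (e.g. a gradient step) without touching statics.
params = {path: p - 0.1 for path, p in params.items()}
module = combine(skeleton, params)
```

Inox performs the same split at the PyTree level, so that only the dynamic content needs to flow through JAX transformations while the static definition is reused as-is.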

For more information, check out the documentation and tutorials at [inox.readthedocs.io](https://inox.readthedocs.io).

## Contributing

If you have a question, encounter an issue, or would like to contribute, please read our [contributing guidelines](https://github.com/francois-rozet/inox/blob/master/CONTRIBUTING.md).
