![Inox's banner](https://raw.githubusercontent.com/francois-rozet/inox/master/docs/images/banner.svg)
# Stainless neural networks in JAX
Inox is a minimal [JAX](https://github.com/google/jax) library for neural networks with an intuitive PyTorch-like syntax. As with [Equinox](https://github.com/patrick-kidger/equinox), modules are represented as PyTrees, which enables complex architectures, easy manipulations, and functional transformations.
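Concretely, "modules are PyTrees" means that JAX knows how to flatten a module into its array leaves and rebuild it afterwards, so generic tree utilities and transformations such as `jax.grad` accept modules directly. Inox's `nn.Module` handles this registration for you. The toy class below is not Inox's implementation but a minimal sketch of the idea, using JAX's `jax.tree_util.register_pytree_node_class`. `Affine` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Affine:
    """Hypothetical toy module: y = x @ weight + bias (not Inox's implementation)."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return x @ self.weight + self.bias

    def tree_flatten(self):
        # Children are the array leaves; the auxiliary data (None here) is static.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


layer = Affine(jnp.ones((3, 2)), jnp.zeros(2))

# Because the module is a PyTree, tree utilities see its parameters...
print([leaf.shape for leaf in jax.tree_util.tree_leaves(layer)])  # [(3, 2), (2,)]

# ...and JAX transformations accept it as an argument; the gradient is an Affine too.
grads = jax.grad(lambda m, x: m(x).sum())(layer, jnp.ones(3))
print(type(grads).__name__, grads.weight.shape)  # Affine (3, 2)
```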
Inox aims to be a leaner version of [Equinox](https://github.com/patrick-kidger/equinox) by retaining only its core features: PyTrees and lifted transformations. In addition, Inox takes inspiration from other projects like [NNX](https://github.com/cgarciae/nnx) and [Serket](https://github.com/ASEM000/serket) to provide a versatile interface.
> Inox means "stainless steel" in French 🔪
## Installation
The `inox` package is available on [PyPI](https://pypi.org/project/inox), which means it is installable via `pip`.
```
pip install inox
```
Alternatively, if you need the latest features, you can install it from the repository.
```
pip install git+https://github.com/francois-rozet/inox
```
## Getting started
Modules are defined with an intuitive PyTorch-like syntax,
```python
import jax
import inox.nn as nn

init_key, data_key = jax.random.split(jax.random.key(0))

class MLP(nn.Module):
    def __init__(self, key):
        keys = jax.random.split(key, 3)

        self.l1 = nn.Linear(3, 64, key=keys[0])
        self.l2 = nn.Linear(64, 64, key=keys[1])
        self.l3 = nn.Linear(64, 3, key=keys[2])
        self.relu = nn.ReLU()

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))

        return x

model = MLP(init_key)
```
and are compatible with JAX transformations.
```python
X = jax.random.normal(data_key, (1024, 3))
Y = jax.numpy.sort(X, axis=-1)

@jax.jit
def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = jax.grad(loss_fn)(model, X, Y)
```
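For comparison, the sketch below writes a model with the same layer sizes and the same loss in plain, functional JAX. The parameters are an explicit PyTree (a list of `(weight, bias)` pairs), the batch is handled with `jax.vmap` over the data only, and one gradient-descent step is a `jax.tree_util.tree_map` over parameters and gradients. The helper names (`init_mlp`, `mlp`, `sgd_step`), the initialization scheme, the `(in, out)` weight layout and the learning rate are assumptions for illustration, not Inox's `nn.Linear` internals.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.01  # assumed value, for illustration only


def init_mlp(key, sizes=(3, 64, 64, 3)):
    """Hypothetical initializer: one (weight, bias) pair per linear layer."""
    layer_sizes = list(zip(sizes[:-1], sizes[1:]))
    keys = jax.random.split(key, len(layer_sizes))
    params = []
    for i, (fan_in, fan_out) in enumerate(layer_sizes):
        # Assumed scheme: normal weights scaled by 1/sqrt(fan_in), zero biases.
        w = jax.random.normal(keys[i], (fan_in, fan_out)) / jnp.sqrt(fan_in)
        params.append((w, jnp.zeros(fan_out)))
    return params


def mlp(params, x):
    """Single-sample forward pass: Linear, ReLU, Linear, ReLU, Linear."""
    *hidden, (w_out, b_out) = params
    for w, b in hidden:
        x = jax.nn.relu(x @ w + b)
    return x @ w_out + b_out


def loss_fn(params, x, y):
    # Batch over the data only; the parameters are shared (in_axes=None).
    pred = jax.vmap(mlp, in_axes=(None, 0))(params, x)
    return jnp.mean((y - pred) ** 2)


@jax.jit
def sgd_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # grads mirrors the structure of params, so the update is a leaf-wise map.
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


init_key, data_key = jax.random.split(jax.random.key(0))
X = jax.random.normal(data_key, (1024, 3))
Y = jnp.sort(X, axis=-1)

params = init_mlp(init_key)
initial_loss = loss_fn(params, X, Y)
for _ in range(50):
    params = sgd_step(params, X, Y)
print(float(initial_loss), float(loss_fn(params, X, Y)))
```

With a bias in every layer, this sketch has 3·64 + 64 + 64·64 + 64 + 64·3 + 3 = 4611 parameters. Since `grads` in the Inox example has the same tree structure as `model`, the analogous update there would presumably be `jax.tree_util.tree_map(lambda p, g: p - lr * g, model, grads)`, as long as every leaf of `model` is an array.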
However, if a tree contains strings or boolean flags, it becomes incompatible with JAX transformations. For this reason, Inox provides lifted transformations that consider all non-array leaves as static.
```python
import inox

model.name = 'stainless'

@inox.jit
def loss_fn(model, x, y):
    pred = inox.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = inox.grad(loss_fn)(model, X, Y)
```
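To illustrate the idea behind a lifted transformation, the sketch below shows one possible way to build a jit that treats every non-array leaf as static on top of plain `jax.jit`: flatten the arguments, trace only the array leaves, and pass the remaining leaves and the tree structure through `static_argnums`. `lifted_jit` and `is_array` are hypothetical helpers, not Inox's implementation. With this design, changing a static value (such as `square` below) triggers a recompilation, and static leaves must be hashable.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def is_array(leaf):
    # Hypothetical rule: JAX and NumPy arrays are traced, everything else is static.
    return isinstance(leaf, (jax.Array, np.ndarray))


def lifted_jit(fun):
    """Hypothetical lifted jit: array leaves are traced, other leaves are static."""

    @functools.partial(jax.jit, static_argnums=(1, 2))
    def jitted(arrays, statics, treedef):
        # Put each leaf back in place: a tracer for arrays, the static value otherwise.
        leaves = [a if s is None else s for a, s in zip(arrays, statics)]
        args = jax.tree_util.tree_unflatten(treedef, leaves)
        return fun(*args)

    @functools.wraps(fun)
    def wrapper(*args):
        leaves, treedef = jax.tree_util.tree_flatten(args)
        # None marks an empty slot; it is never a leaf, since JAX treats None as an empty subtree.
        arrays = [leaf if is_array(leaf) else None for leaf in leaves]
        statics = tuple(None if is_array(leaf) else leaf for leaf in leaves)
        return jitted(arrays, statics, treedef)

    return wrapper


params = {"w": jnp.ones(3), "name": "stainless", "square": True}


@lifted_jit
def f(p, x):
    y = p["w"] * x
    # Ordinary Python control flow on a static leaf works inside the traced function.
    return y**2 if p["square"] else y


print(f(params, jnp.arange(3.0)))
```

Calling plain `jax.jit(f)` with the same `params` would fail, because the string leaf `"stainless"` is not a valid JAX type.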
For more information, check out the documentation at [inox.readthedocs.io](https://inox.readthedocs.io).
## Contributing
If you have a question or an issue, or would like to contribute, please read our [contributing guidelines](https://github.com/francois-rozet/inox/blob/master/CONTRIBUTING.md).
## Package metadata

Inox 0.5.0 was published on PyPI on 2024-03-21 by François Rozet (francois.rozet@outlook.com). It requires Python 3.9 or later, `jax>=0.4.14` and `einops>=0.5.0`. Issues are tracked at [github.com/francois-rozet/inox/issues](https://github.com/francois-rozet/inox/issues).