Name | reax |
Version | 0.3.1 |
home_page | None |
Summary | REAX: A simple training framework for JAX-based projects |
upload_time | 2025-01-24 12:00:41 |
maintainer | None |
docs_url | None |
author | None |
requires_python | >=3.9 |
license | None |
keywords | machine learning, jax, research |
VCS | |
bugtrack_url | |
requirements | No requirements were recorded. |
Travis-CI | No Travis. |
coveralls test coverage | No coveralls. |
REAX
====
.. image:: https://codecov.io/gh/muhrin/reax/branch/develop/graph/badge.svg
   :target: https://codecov.io/gh/muhrin/reax
   :alt: Coverage

.. image:: https://github.com/muhrin/reax/actions/workflows/ci.yml/badge.svg
   :target: https://github.com/muhrin/reax/actions/workflows/ci.yml
   :alt: Tests

.. image:: https://img.shields.io/pypi/v/reax.svg
   :target: https://pypi.python.org/pypi/reax/
   :alt: Latest Version

.. image:: https://img.shields.io/pypi/wheel/reax.svg
   :target: https://pypi.python.org/pypi/reax/

.. image:: https://img.shields.io/pypi/pyversions/reax.svg
   :target: https://pypi.python.org/pypi/reax/

.. image:: https://img.shields.io/pypi/l/reax.svg
   :target: https://pypi.python.org/pypi/reax/

REAX: A simple training framework for JAX-based projects

REAX is based on PyTorch Lightning and tries to bring a similar level of ease of use and
customizability to the world of training JAX models. Much of Lightning's API has been adopted,
with some modifications to accommodate JAX's pure-function-based approach.
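To illustrate what a pure-function-based approach means in practice, here is a minimal sketch in plain JAX (it is not REAX code). Parameters live outside the function and are passed in explicitly, and an update returns a new parameter tree instead of mutating a model in place. The names ``loss_fn`` and ``sgd_step`` and the toy linear model are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Mean squared error of a linear model; depends only on its arguments."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def sgd_step(params, x, y, lr=0.1):
    """Return the loss and a *new* parameter pytree; nothing is mutated."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return loss, new_params


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3  # synthetic targets

params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
first_loss, params = sgd_step(params, x, y)
for _ in range(100):
    last_loss, params = sgd_step(params, x, y)
```

Because the state is threaded through explicitly, a framework built on JAX has to hold and hand back the parameters itself, which is presumably why the example further down uses accessor calls such as ``self.parameters()`` and ``self.set_parameters(...)``.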
Quick start
-----------
.. code-block:: shell

    pip install reax
REAX example
------------
Define the training workflow. Here's a toy example:
.. code-block:: python

    # main.py
    # ! pip install torchvision
    from functools import partial
    import jax, optax, reax, flax.linen as linen
    import torch.utils.data as data, torchvision as tv


    class Autoencoder(linen.Module):
        def setup(self):
            super().__init__()
            self.encoder = linen.Sequential([linen.Dense(128), linen.relu, linen.Dense(3)])
            self.decoder = linen.Sequential([linen.Dense(128), linen.relu, linen.Dense(28 * 28)])

        def __call__(self, x):
            z = self.encoder(x)
            return self.decoder(z)


    # --------------------------------
    # Step 1: Define a REAX Module
    # --------------------------------
    # A reax.Module defines a full *system*
    # (i.e. an LLM, diffusion model, autoencoder, or simple image classifier).


    class ReaxAutoEncoder(reax.Module):
        def __init__(self):
            super().__init__()
            self.ae = Autoencoder()

        def setup(self, stage: "reax.Stage", batch) -> None:
            # Initialise the parameters lazily from the first batch
            if self.parameters() is None:
                x = batch[0].reshape(len(batch[0]), -1)
                params = self.ae.init(self.rng_key(), x)
                self.set_parameters(params)

        def __call__(self, *args, **kwargs):
            return self.forward(*args, **kwargs)

        def forward(self, x):
            embedding = jax.jit(self.ae.encoder.apply)(self.parameters()["params"]["encoder"], x)
            return embedding

        def training_step(self, batch, batch_idx):
            # Flatten each image into a vector before passing it to the model
            x = batch[0].reshape(len(batch[0]), -1)
            loss, grads = jax.value_and_grad(self.loss_fn, argnums=0)(self.parameters(), x, self.ae)
            self.log("train_loss", loss, on_step=True, prog_bar=True)
            return loss, grads

        @staticmethod
        @partial(jax.jit, static_argnums=2)
        def loss_fn(params, x, model):
            predictions = model.apply(params, x)
            return optax.losses.squared_error(predictions, x).mean()

        def configure_optimizers(self):
            opt = optax.adam(learning_rate=1e-3)
            state = opt.init(self.parameters())
            return opt, state


    # -------------------
    # Step 2: Define data
    # -------------------
    dataset = tv.datasets.MNIST(".", download=True, transform=jax.numpy.asarray)
    train, val = data.random_split(dataset, [55000, 5000])

    # -------------------
    # Step 3: Train
    # -------------------
    autoencoder = ReaxAutoEncoder()
    trainer = reax.Trainer(autoencoder)
    trainer.fit(reax.ReaxDataLoader(train), reax.ReaxDataLoader(val))

Here we reproduce an example from PyTorch Lightning, so we use torchvision to fetch the data, but for real models
there is no need to use it or PyTorch at all.
Run the model from the terminal:

.. code-block:: bash

    pip install reax torchvision
    python main.py

Raw data
{
"_id": null,
"home_page": null,
"name": "reax",
"maintainer": null,
"docs_url": null,
"requires_python": ">=3.9",
"maintainer_email": null,
"keywords": "machine learning, jax, research",
"author": null,
"author_email": "Martin Uhrin <martin.uhrin.10@ucl.ac.uk>",
"download_url": "https://files.pythonhosted.org/packages/9c/18/f5016c6bc07d44f8d743cee9a2315855277fa59bbf5b2651bef72698b2f6/reax-0.3.1.tar.gz",
"platform": null,
"description": "\nREAX\n====\n\n.. image:: https://codecov.io/gh/muhrin/reax/branch/develop/graph/badge.svg\n :target: https://codecov.io/gh/muhrin/reax\n :alt: Coverage\n\n.. image:: https://github.com/muhrin/reax/actions/workflows/ci.yml/badge.svg\n :target: https://github.com/muhrin/reax/actions/workflows/ci.yml\n :alt: Tests\n\n.. image:: https://img.shields.io/pypi/v/reax.svg\n :target: https://pypi.python.org/pypi/reax/\n :alt: Latest Version\n\n.. image:: https://img.shields.io/pypi/wheel/reax.svg\n :target: https://pypi.python.org/pypi/reax/\n\n.. image:: https://img.shields.io/pypi/pyversions/reax.svg\n :target: https://pypi.python.org/pypi/reax/\n\n.. image:: https://img.shields.io/pypi/l/reax.svg\n :target: https://pypi.python.org/pypi/reax/\n\n\nREAX: A simple training framework for JAX-based projects\n\nREAX is based on PyTorch Lightning and tries to bring a similar level of easy-of-use and\ncustomizability to the world of training JAX models. Much of lightning's API has been adopted\nwith some modifications being made to accommodate JAX's pure function based approach.\n\n\nQuick start\n-----------\n\n.. code-block:: shell\n\n pip install reax\n\n\nREAX example\n------------\n\nDefine the training workflow. Here's a toy example:\n\n.. code-block:: python\n\n # main.py\n # ! 
pip install torchvision\n from functools import partial\n import jax, optax, reax, flax.linen as linen\n import torch.utils.data as data, torchvision as tv\n\n\n class Autoencoder(linen.Module):\n def setup(self):\n super().__init__()\n self.encoder = linen.Sequential([linen.Dense(128), linen.relu, linen.Dense(3)])\n self.decoder = linen.Sequential([linen.Dense(128), linen.relu, linen.Dense(28 * 28)])\n\n def __call__(self, x):\n z = self.encoder(x)\n return self.decoder(z)\n\n\n # --------------------------------\n # Step 1: Define a REAX Module\n # --------------------------------\n # A ReaxModule (nn.Module subclass) defines a full *system*\n # (ie: an LLM, diffusion model, autoencoder, or simple image classifier).\n\n\n class ReaxAutoEncoder(reax.Module):\n def __init__(self):\n super().__init__()\n self.ae = Autoencoder()\n\n def setup(self, stage: \"reax.Stage\", batch) -> None:\n if self.parameters() is None:\n x = batch[0].reshape(len(batch[0]), -1)\n params = self.ae.init(self.rng_key(), x)\n self.set_parameters(params)\n\n def __call__(self, *args, **kwargs):\n return self.forward(*args, **kwargs)\n\n def forward(self, x):\n embedding = jax.jit(self.ae.encoder.apply)(self.parameters()[\"params\"][\"encoder\"], x)\n return embedding\n\n def training_step(self, batch, batch_idx):\n x = batch[0].reshape(len(batch[0]), -1)\n loss, grads = jax.value_and_grad(self.loss_fn, argnums=0)(self.parameters(), x, self.ae)\n self.log(\"train_loss\", loss, on_step=True, prog_bar=True)\n return loss, grads\n\n @staticmethod\n @partial(jax.jit, static_argnums=2)\n def loss_fn(params, x, model):\n predictions = model.apply(params, x)\n return optax.losses.squared_error(predictions, x).mean()\n\n def configure_optimizers(self):\n opt = optax.adam(learning_rate=1e-3)\n state = opt.init(self.parameters())\n return opt, state\n\n\n # -------------------\n # Step 2: Define data\n # -------------------\n dataset = tv.datasets.MNIST(\".\", download=True, 
transform=jax.numpy.asarray)\n train, val = data.random_split(dataset, [55000, 5000])\n\n # -------------------\n # Step 3: Train\n # -------------------\n autoencoder = ReaxAutoEncoder()\n trainer = reax.Trainer(autoencoder)\n trainer.fit(reax.ReaxDataLoader(train), reax.ReaxDataLoader(val))\n\nHere, we reproduce an example from PyTorch Lightning, so we use torch vision to fetch the data, but for real models\nthere's no need to use this or pytorch at all.\nRun the model on the terminal\n\n\n.. code-block:: bash\n\n pip install reax torchvision\n python main.py\n",
"bugtrack_url": null,
"license": null,
"summary": "REAX: A simple training framework for JAX-based projects",
"version": "0.3.1",
"project_urls": {
"Home": "https://github.com/muhrin/reax",
"Source": "https://github.com/muhrin/reax"
},
"split_keywords": [
"machine learning",
" jax",
" research"
],
"urls": [
{
"comment_text": null,
"digests": {
"blake2b_256": "47d5792145d9604b218b1b51300a94ccfd2d8ece3d9900ee28a5a653bbdac384",
"md5": "eef530a0f8652f3f23f404e17f8aea89",
"sha256": "68318043c72e67a40a5038470fa4de55e91db00d4af9e917b191c5da2d4b18bc"
},
"downloads": -1,
"filename": "reax-0.3.1-py3-none-any.whl",
"has_sig": false,
"md5_digest": "eef530a0f8652f3f23f404e17f8aea89",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.9",
"size": 111971,
"upload_time": "2025-01-24T12:00:38",
"upload_time_iso_8601": "2025-01-24T12:00:38.903076Z",
"url": "https://files.pythonhosted.org/packages/47/d5/792145d9604b218b1b51300a94ccfd2d8ece3d9900ee28a5a653bbdac384/reax-0.3.1-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": null,
"digests": {
"blake2b_256": "9c18f5016c6bc07d44f8d743cee9a2315855277fa59bbf5b2651bef72698b2f6",
"md5": "e961e0e7535de43ef51d82b54fb52c18",
"sha256": "dfed824d3d9d677e107fe46fdabf904b032d44869503e09129e02d6989f80056"
},
"downloads": -1,
"filename": "reax-0.3.1.tar.gz",
"has_sig": false,
"md5_digest": "e961e0e7535de43ef51d82b54fb52c18",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.9",
"size": 85949,
"upload_time": "2025-01-24T12:00:41",
"upload_time_iso_8601": "2025-01-24T12:00:41.488222Z",
"url": "https://files.pythonhosted.org/packages/9c/18/f5016c6bc07d44f8d743cee9a2315855277fa59bbf5b2651bef72698b2f6/reax-0.3.1.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2025-01-24 12:00:41",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "muhrin",
"github_project": "reax",
"travis_ci": false,
"coveralls": false,
"github_actions": true,
"lcname": "reax"
}