reax

:Name: reax
:Version: 0.3.1
:Summary: REAX: A simple training framework for JAX-based projects
:Upload time: 2025-01-24 12:00:41
:Requires Python: >=3.9
:Keywords: machine learning, jax, research

REAX
====

.. image:: https://codecov.io/gh/muhrin/reax/branch/develop/graph/badge.svg
    :target: https://codecov.io/gh/muhrin/reax
    :alt: Coverage

.. image:: https://github.com/muhrin/reax/actions/workflows/ci.yml/badge.svg
    :target: https://github.com/muhrin/reax/actions/workflows/ci.yml
    :alt: Tests

.. image:: https://img.shields.io/pypi/v/reax.svg
    :target: https://pypi.python.org/pypi/reax/
    :alt: Latest Version

.. image:: https://img.shields.io/pypi/wheel/reax.svg
    :target: https://pypi.python.org/pypi/reax/

.. image:: https://img.shields.io/pypi/pyversions/reax.svg
    :target: https://pypi.python.org/pypi/reax/

.. image:: https://img.shields.io/pypi/l/reax.svg
    :target: https://pypi.python.org/pypi/reax/


REAX: A simple training framework for JAX-based projects

REAX is based on PyTorch Lightning and aims to bring a similar level of ease of use and
customizability to training JAX models. Much of Lightning's API has been adopted, with some
modifications made to accommodate JAX's pure-function approach.
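
To make the "pure function" point concrete, here is a minimal, REAX-independent sketch of the
pattern the example below builds on: parameters are passed explicitly and gradients are computed by
transforming a pure loss function with ``jax.value_and_grad``. The tiny linear model and its
parameter names are purely illustrative.

.. code-block:: python

    import jax
    import jax.numpy as jnp


    def loss_fn(params, x):
        # The parameters are an explicit pytree argument, not hidden module state.
        pred = x @ params["w"] + params["b"]
        return jnp.mean((pred - x[:, :1]) ** 2)


    params = {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}
    x = jnp.ones((4, 3))
    # Gradients come from transforming the pure function, mirroring what
    # training_step does in the REAX example below.
    loss, grads = jax.value_and_grad(loss_fn)(params, x)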


Quick start
-----------

.. code-block:: shell

    pip install reax


REAX example
------------

Define the training workflow. Here's a toy example:

.. code-block:: python

    # main.py
    # ! pip install torchvision
    from functools import partial
    import jax, optax, reax, flax.linen as linen
    import torch.utils.data as data, torchvision as tv


    class Autoencoder(linen.Module):
        def setup(self):
            # setup() plays the role of __init__ in Flax modules; no super().__init__() call is needed
            self.encoder = linen.Sequential([linen.Dense(128), linen.relu, linen.Dense(3)])
            self.decoder = linen.Sequential([linen.Dense(128), linen.relu, linen.Dense(28 * 28)])

        def __call__(self, x):
            z = self.encoder(x)
            return self.decoder(z)


    # --------------------------------
    # Step 1: Define a REAX Module
    # --------------------------------
    # A REAX module (a reax.Module subclass) defines a full *system*
    # (i.e. an LLM, diffusion model, autoencoder, or simple image classifier).


    class ReaxAutoEncoder(reax.Module):
        def __init__(self):
            super().__init__()
            self.ae = Autoencoder()

        def setup(self, stage: "reax.Stage", batch) -> None:
            if self.parameters() is None:
                x = batch[0].reshape(len(batch[0]), -1)
                params = self.ae.init(self.rng_key(), x)
                self.set_parameters(params)

        def __call__(self, *args, **kwargs):
            return self.forward(*args, **kwargs)

        def forward(self, x):
            embedding = jax.jit(self.ae.encoder.apply)(self.parameters()["params"]["encoder"], x)
            return embedding

        def training_step(self, batch, batch_idx):
            x = batch[0].reshape(len(batch[0]), -1)
            loss, grads = jax.value_and_grad(self.loss_fn, argnums=0)(self.parameters(), x, self.ae)
            self.log("train_loss", loss, on_step=True, prog_bar=True)
            return loss, grads

        @staticmethod
        @partial(jax.jit, static_argnums=2)
        def loss_fn(params, x, model):
            predictions = model.apply(params, x)
            return optax.losses.squared_error(predictions, x).mean()

        def configure_optimizers(self):
            opt = optax.adam(learning_rate=1e-3)
            state = opt.init(self.parameters())
            return opt, state


    # -------------------
    # Step 2: Define data
    # -------------------
    dataset = tv.datasets.MNIST(".", download=True, transform=jax.numpy.asarray)
    train, val = data.random_split(dataset, [55000, 5000])

    # -------------------
    # Step 3: Train
    # -------------------
    autoencoder = ReaxAutoEncoder()
    trainer = reax.Trainer(autoencoder)
    trainer.fit(reax.ReaxDataLoader(train), reax.ReaxDataLoader(val))

Here we reproduce an example from PyTorch Lightning, so torchvision is used to fetch the data; for
real models there is no need to use torchvision, or PyTorch at all.
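
For instance, a torch-free pipeline could be as simple as the sketch below. It assumes
``reax.ReaxDataLoader`` accepts a plain indexable sequence of ``(input, target)`` pairs, which has
not been verified against the REAX API; check the REAX documentation for the supported dataset
types.

.. code-block:: python

    # Hypothetical torch-free data pipeline: random arrays standing in for MNIST.
    # Assumes reax.ReaxDataLoader can wrap a plain list of (input, target) pairs.
    import numpy as np
    import jax.numpy as jnp

    rng = np.random.default_rng(0)
    images = jnp.asarray(rng.random((1000, 28 * 28)), dtype=jnp.float32)
    samples = [(images[i], 0) for i in range(len(images))]
    train_loader = reax.ReaxDataLoader(samples[:900])
    val_loader = reax.ReaxDataLoader(samples[900:])
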
Run the model from the terminal:


.. code-block:: bash

    pip install reax torchvision
    python main.py
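
After ``trainer.fit`` returns, the trained module can be called directly to embed new samples. The
short sketch below reuses the names from the example above and assumes the parameters initialised
in ``setup`` are still attached to ``autoencoder``.

.. code-block:: python

    import jax.numpy as jnp

    x, _ = dataset[0]                    # an MNIST digit, already a jnp array via the transform
    x = jnp.asarray(x).reshape(1, -1)    # flatten to (1, 784), matching training_step
    embedding = autoencoder(x)           # __call__ -> forward -> encoder.apply
    print(embedding.shape)               # expected: (1, 3)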

            
