flax


Name: flax
Version: 0.10.3
Summary: Flax: A neural network library for JAX designed for flexibility
Upload time: 2025-02-10 17:34:17
Requires Python: >=3.10
<div align="center">
<img src="https://raw.githubusercontent.com/google/flax/main/images/flax_logo_250px.png" alt="logo"></img>
</div>

# Flax: A neural network library and ecosystem for JAX designed for flexibility

![Build](https://github.com/google/flax/workflows/Build/badge.svg?branch=main) [![coverage](https://badgen.net/codecov/c/gh/google/flax)](https://codecov.io/gh/google/flax)

[**Overview**](#overview)
| [**Quick install**](#quick-install)
| [**What does Flax look like?**](#what-does-flax-look-like)
| [**Documentation**](https://flax.readthedocs.io/)

Released in 2024, Flax NNX is a new simplified Flax API that is designed to make
it easier to create, inspect, debug, and analyze neural networks in
[JAX](https://jax.readthedocs.io/). It achieves this by adding first-class support
for Python reference semantics. This allows users to express their models using
regular Python objects, enabling reference sharing and mutability.

Flax NNX evolved from the [Flax Linen API](https://flax-linen.readthedocs.io/), which
was released in 2020 by engineers and researchers at Google Brain in close collaboration
with the JAX team.

You can learn more about Flax NNX on the [dedicated Flax documentation site](https://flax.readthedocs.io/). Make sure you check out:

* [Flax NNX basics](https://flax.readthedocs.io/en/latest/nnx_basics.html)
* [MNIST tutorial](https://flax.readthedocs.io/en/latest/mnist_tutorial.html)
* [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html)
* [Evolution from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html)

**Note:** Flax Linen's [documentation has its own site](https://flax-linen.readthedocs.io/).

The Flax team's mission is to serve the growing JAX neural network
research ecosystem, both within Alphabet and in the broader community,
and to explore the use cases where JAX shines. We use GitHub for almost
all of our coordination and planning, as well as for discussing
upcoming design changes. We welcome feedback on any of our discussion,
issue, and pull request threads.

You can make feature requests, let us know what you are working on,
report issues, ask questions in our [Flax GitHub discussion
forum](https://github.com/google/flax/discussions).

We expect to improve Flax, but we don't anticipate significant
breaking changes to the core API. We use [Changelog](https://github.com/google/flax/tree/main/CHANGELOG.md)
entries and deprecation warnings when possible.

In case you want to reach us directly, we're at flax-dev@google.com.

## Overview

Flax is a high-performance neural network library and ecosystem for
JAX that is **designed for flexibility**:
Try new forms of training by forking an example and modifying its training
loop, rather than by adding features to a framework.

Flax is being developed in close collaboration with the JAX team and
comes with everything you need to start your research, including:

* **Neural network API** (`flax.nnx`): Including [`Linear`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear), [`Conv`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Conv), [`BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm), [`LayerNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.LayerNorm), [`GroupNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.GroupNorm), [Attention](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html) ([`MultiHeadAttention`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html#flax.nnx.MultiHeadAttention)), [`LSTMCell`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/recurrent.html#flax.nnx.nn.recurrent.LSTMCell), [`GRUCell`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/recurrent.html#flax.nnx.nn.recurrent.GRUCell), [`Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout).

* **Utilities and patterns**: replicated training, serialization and checkpointing, metrics, prefetching on device.

* **Educational examples**: [MNIST](https://flax.readthedocs.io/en/latest/mnist_tutorial.html), [Inference/sampling with the Gemma language model (transformer)](https://github.com/google/flax/tree/main/examples/gemma), [Transformer LM1B](https://github.com/google/flax/tree/main/examples/lm1b_nnx).

## Quick install

Flax uses JAX, so do check out [JAX installation instructions on CPUs, GPUs and TPUs](https://jax.readthedocs.io/en/latest/installation.html).

You will need Python 3.10 or later. Install Flax from PyPI:

```
pip install flax
```

To install the latest development version of Flax directly from the GitHub repository, you can use:

```
pip install --upgrade git+https://github.com/google/flax.git
```

To also install optional dependencies (such as `matplotlib`) that some features
need but that are not installed by default, you can use:

```bash
pip install "flax[all]"
```

## What does Flax look like?

We provide three examples using the Flax API: a simple multi-layer perceptron (MLP), a convolutional neural network (CNN), and an autoencoder.

To learn more about the `Module` abstraction, check out our [docs](https://flax.readthedocs.io/) and our [broad intro to the Module abstraction](https://github.com/google/flax/blob/main/docs/linen_intro.ipynb) (written for the Flax Linen API). For additional concrete demonstrations of best practices, refer to our
[guides](https://flax.readthedocs.io/en/latest/guides/index.html) and
[developer notes](https://flax.readthedocs.io/en/latest/developer_notes/index.html).

Example of an MLP:

```py
import jax
from flax import nnx


class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    # Linear -> BatchNorm -> Dropout -> GELU, then the output projection.
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)
```

Example of a CNN:

```py
from functools import partial

import jax
from flax import nnx


class CNN(nnx.Module):
  """CNN for 28x28 single-channel inputs (e.g. MNIST) producing 10 logits."""

  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    # 'SAME' convolutions plus two 2x2 pools: 28 -> 14 -> 7, so 7 * 7 * 64 = 3136.
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x
```

Example of an autoencoder:


```py
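import jax
from flax import nnx

# The encoder maps 2 -> 10 features; the decoder maps 10 -> 2 back.
Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)

class AutoEncoder(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.encoder = Encoder(rngs)
    self.decoder = Decoder(rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    # Reconstruct the input by encoding and then decoding it.
    return self.decoder(self.encoder(x))

  def encode(self, x: jax.Array) -> jax.Array:
    # Return only the latent representation.
    return self.encoder(x)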
```

## Citing Flax

To cite this repository:

```
@software{flax2020github,
  author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
  title = {{F}lax: A neural network library and ecosystem for {JAX}},
  url = {http://github.com/google/flax},
  version = {0.10.3},
  year = {2024},
}
```

In the above BibTeX entry, names are in alphabetical order, the version number
is intended to be that from [flax/version.py](https://github.com/google/flax/blob/main/flax/version.py), and the year corresponds to the project's open-source release.

## Note

Flax is an open source project maintained by a dedicated team at Google DeepMind, but is not an official Google product.

            

Raw data

            {
    "_id": null,
    "home_page": null,
    "name": "flax",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.10",
    "maintainer_email": null,
    "keywords": null,
    "author": null,
    "author_email": "Flax team <flax-dev@google.com>",
    "download_url": "https://files.pythonhosted.org/packages/95/63/79b204d9f99e855ff36207d4810b4932f83dc7c100e9512eebe1466f0c4f/flax-0.10.3.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": null,
    "summary": "Flax: A neural network library for JAX designed for flexibility",
    "version": "0.10.3",
    "project_urls": {
        "homepage": "https://github.com/google/flax"
    },
    "split_keywords": [],
    "urls": [
        {
            "comment_text": null,
            "digests": {
                "blake2b_256": "77285d84597a2b1703486b554ab430f1e22f5f26a732ad98c398a8b414cd22c4",
                "md5": "b79ceffb867adcca1f42684b01b401a4",
                "sha256": "7158b5dd6a05837e662a1ce1beea7adbad6d3612c0551c986b1c0a56071e3021"
            },
            "downloads": -1,
            "filename": "flax-0.10.3-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "b79ceffb867adcca1f42684b01b401a4",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.10",
            "size": 435734,
            "upload_time": "2025-02-10T17:34:15",
            "upload_time_iso_8601": "2025-02-10T17:34:15.117881Z",
            "url": "https://files.pythonhosted.org/packages/77/28/5d84597a2b1703486b554ab430f1e22f5f26a732ad98c398a8b414cd22c4/flax-0.10.3-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": null,
            "digests": {
                "blake2b_256": "956379b204d9f99e855ff36207d4810b4932f83dc7c100e9512eebe1466f0c4f",
                "md5": "626f560a332bf80fd37c8add0b137bb1",
                "sha256": "29cde8cf05ffbff39b7f7167f0fe9916694cce76ce4c14e8be3549c1fd1b7c81"
            },
            "downloads": -1,
            "filename": "flax-0.10.3.tar.gz",
            "has_sig": false,
            "md5_digest": "626f560a332bf80fd37c8add0b137bb1",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.10",
            "size": 5181656,
            "upload_time": "2025-02-10T17:34:17",
            "upload_time_iso_8601": "2025-02-10T17:34:17.363962Z",
            "url": "https://files.pythonhosted.org/packages/95/63/79b204d9f99e855ff36207d4810b4932f83dc7c100e9512eebe1466f0c4f/flax-0.10.3.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2025-02-10 17:34:17",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "google",
    "github_project": "flax",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "lcname": "flax"
}