flax


* Name: flax
* Version: 0.8.2
* Summary: Flax: A neural network library for JAX designed for flexibility
* Upload time: 2024-03-14 11:35:08
* Requires Python: >=3.9
<div align="center">
<img src="https://raw.githubusercontent.com/google/flax/main/images/flax_logo_250px.png" alt="logo">
</div>

# Flax: A neural network library and ecosystem for JAX designed for flexibility

![Build](https://github.com/google/flax/workflows/Build/badge.svg?branch=main) [![coverage](https://badgen.net/codecov/c/gh/google/flax)](https://codecov.io/gh/google/flax)


[**Overview**](#overview)
| [**Quick install**](#quick-install)
| [**What does Flax look like?**](#what-does-flax-look-like)
| [**Documentation**](https://flax.readthedocs.io/)

This README is a very short intro. **To learn everything you need to know about Flax, refer to our [full documentation](https://flax.readthedocs.io/).**

Flax was originally started by engineers and researchers within the Brain Team in Google Research (in close collaboration with the JAX team), and is now developed jointly with the open source community.

Flax is used by hundreds of people across various Alphabet research
departments in their daily work, as well as by a [growing community
of open source
projects](https://github.com/google/flax/network/dependents?dependent_type=REPOSITORY).

The Flax team's mission is to serve the growing JAX neural network
research ecosystem, both within Alphabet and in the broader community,
and to explore the use cases where JAX shines. We use GitHub for almost
all of our coordination and planning, and it is also where we discuss
upcoming design changes. We welcome feedback on any of our discussion,
issue and pull request threads. We are in the process of moving the
remaining internal design docs and conversation threads to GitHub
discussions, issues and pull requests. We hope to engage more and more
with the needs and questions of the broader ecosystem. Please let
us know how we can help!

Please report any feature requests,
issues, questions or concerns in our [discussion
forum](https://github.com/google/flax/discussions), or just let us
know what you're working on!

We expect to keep improving Flax, but we don't anticipate significant
breaking changes to the core API. When possible, we announce changes through
[Changelog](https://github.com/google/flax/tree/main/CHANGELOG.md) entries and deprecation warnings.

In case you want to reach us directly, we're at flax-dev@google.com.

## Overview

Flax is a high-performance neural network library and ecosystem for
JAX that is **designed for flexibility**:
Try new forms of training by forking an example and by modifying the training
loop, not by adding features to a framework.

Flax is being developed in close collaboration with the JAX team and
comes with everything you need to start your research, including:

* **Neural network API** (`flax.linen`): Dense, Conv, {Batch|Layer|Group} Norm, Attention, Pooling, {LSTM|GRU} Cell, Dropout

* **Utilities and patterns**: replicated training, serialization and checkpointing, metrics, prefetching on device

* **Educational examples** that work out of the box: MNIST, LSTM seq2seq, Graph Neural Networks, Sequence Tagging

* **Fast, tuned large-scale end-to-end examples**: CIFAR10, ResNet on ImageNet, Transformer LM1b

## Quick install

You will need Python 3.9 or later and a working [JAX](https://github.com/google/jax/blob/main/README.md)
installation (with or without GPU support; refer to [the instructions](https://github.com/google/jax/blob/main/README.md)).
For a CPU-only version of JAX:

```
pip install --upgrade pip # To support manylinux2010 wheels.
pip install --upgrade jax jaxlib # CPU-only
```

Then, install Flax from PyPI:

```
pip install flax
```

To install the latest development version of Flax directly from GitHub, use:

```
pip install --upgrade git+https://github.com/google/flax.git
```

Some parts of Flax rely on additional dependencies (such as `matplotlib`) that are
not installed by default. To install them as well, use:

```bash
pip install "flax[all]"
```

## What does Flax look like?

We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN and an auto-encoder.

To learn more about the `Module` abstraction, check out our [docs](https://flax.readthedocs.io/) and our [broad intro to the Module abstraction](https://github.com/google/flax/blob/main/docs/linen_intro.ipynb). For additional concrete demonstrations of best practices, refer to our
[guides](https://flax.readthedocs.io/en/latest/guides/index.html) and
[developer notes](https://flax.readthedocs.io/en/latest/developer_notes/index.html).

```py
from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)
```

```py
class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

model = CNN()
batch = jnp.ones((32, 64, 64, 10))  # (N, H, W, C) format
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)
```

```py
class AutoEncoder(nn.Module):
  encoder_widths: Sequence[int]
  decoder_widths: Sequence[int]
  input_shape: Sequence[int]

  def setup(self):
    input_dim = int(np.prod(self.input_shape))
    self.encoder = MLP(self.encoder_widths)
    # tuple() lets callers pass the widths as a list or a tuple.
    self.decoder = MLP(tuple(self.decoder_widths) + (input_dim,))

  def __call__(self, x):
    return self.decode(self.encode(x))

  def encode(self, x):
    assert x.shape[1:] == self.input_shape
    return self.encoder(jnp.reshape(x, (x.shape[0], -1)))

  def decode(self, z):
    z = self.decoder(z)
    x = nn.sigmoid(z)
    x = jnp.reshape(x, (x.shape[0],) + self.input_shape)
    return x

model = AutoEncoder(encoder_widths=[20, 10, 5],
                    decoder_widths=[5, 10, 20],
                    input_shape=(12,))
batch = jnp.ones((16, 12))
variables = model.init(jax.random.key(0), batch)
encoded = model.apply(variables, batch, method=model.encode)
decoded = model.apply(variables, encoded, method=model.decode)
```

## 🤗 Hugging Face

Detailed examples for training and evaluating a variety of Flax models for
Natural Language Processing, Computer Vision, and Speech Recognition are
actively maintained in the [🤗 Transformers repository](https://github.com/huggingface/transformers/tree/main/examples/flax).

As of October 2021, the [19 most-used Transformer architectures](https://huggingface.co/transformers/#supported-frameworks) are supported in Flax
and over 5000 pretrained checkpoints in Flax have been uploaded to the [🤗 Hub](https://huggingface.co/models?library=jax&sort=downloads).

## Citing Flax

To cite this repository:

```
@software{flax2020github,
  author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
  title = {{F}lax: A neural network library and ecosystem for {JAX}},
  url = {http://github.com/google/flax},
  version = {0.8.1},
  year = {2023},
}
```

In the above BibTeX entry, names are in alphabetical order, the version number
is intended to be the one from [flax/version.py](https://github.com/google/flax/blob/main/flax/version.py), and the year corresponds to the project's open-source release.

## Note

Flax is an open source project maintained by a dedicated team in Google Research, but is not an official Google product.

            

Raw data

            {
    "_id": null,
    "home_page": "",
    "name": "flax",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.9",
    "maintainer_email": "",
    "keywords": "",
    "author": "",
    "author_email": "Flax team <flax-dev@google.com>",
    "download_url": "https://files.pythonhosted.org/packages/21/32/968d96d29f8af0d2ea16de8dcab640cc3b3b534a6d0dcdfdd2739a85fa5c/flax-0.8.2.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "",
    "summary": "Flax: A neural network library for JAX designed for flexibility",
    "version": "0.8.2",
    "project_urls": {
        "homepage": "https://github.com/google/flax"
    },
    "split_keywords": [],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "b99259b0a2b5df281206433fa6496b176e95249eb0a8192586f88309d7d5df27",
                "md5": "620eae2a401547625eef7e764fb0b13f",
                "sha256": "911d83e01380fdb3135c309e70981aabd15e7ca038014d7989ddc3cfaf4d0d45"
            },
            "downloads": -1,
            "filename": "flax-0.8.2-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "620eae2a401547625eef7e764fb0b13f",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.9",
            "size": 686754,
            "upload_time": "2024-03-14T11:35:04",
            "upload_time_iso_8601": "2024-03-14T11:35:04.988424Z",
            "url": "https://files.pythonhosted.org/packages/b9/92/59b0a2b5df281206433fa6496b176e95249eb0a8192586f88309d7d5df27/flax-0.8.2-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "2132968d96d29f8af0d2ea16de8dcab640cc3b3b534a6d0dcdfdd2739a85fa5c",
                "md5": "e11dd8a5facb04517665fd0f13c775ed",
                "sha256": "1c4e43ac3cb32e8e15c733cfa3df8d827b61d9ce29b50a7035920bfe9fdaa5b0"
            },
            "downloads": -1,
            "filename": "flax-0.8.2.tar.gz",
            "has_sig": false,
            "md5_digest": "e11dd8a5facb04517665fd0f13c775ed",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.9",
            "size": 2553188,
            "upload_time": "2024-03-14T11:35:08",
            "upload_time_iso_8601": "2024-03-14T11:35:08.295128Z",
            "url": "https://files.pythonhosted.org/packages/21/32/968d96d29f8af0d2ea16de8dcab640cc3b3b534a6d0dcdfdd2739a85fa5c/flax-0.8.2.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2024-03-14 11:35:08",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "google",
    "github_project": "flax",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "lcname": "flax"
}
        