# dm-haiku

- **Version:** 0.0.11 (uploaded 2023-11-10)
- **Summary:** Haiku is a library for building neural networks in JAX.
- **Home page:** https://github.com/deepmind/dm-haiku
- **Author:** DeepMind
- **License:** Apache 2.0

# Haiku: [Sonnet] for [JAX]

[**Overview**](#overview)
| [**Why Haiku?**](#why-haiku)
| [**Quickstart**](#quickstart)
| [**Installation**](#installation)
| [**Examples**](https://github.com/deepmind/dm-haiku/tree/main/examples/)
| [**User manual**](#user-manual)
| [**Documentation**](https://dm-haiku.readthedocs.io/)
| [**Citing Haiku**](#citing-haiku)

![pytest](https://github.com/deepmind/dm-haiku/workflows/pytest/badge.svg)
![docs](https://readthedocs.org/projects/dm-haiku/badge/?version=latest)
![pypi](https://img.shields.io/pypi/v/dm-haiku)

> [!IMPORTANT]
> 📣 **As of July 2023 [Google DeepMind] recommends that new projects adopt
> [Flax] (a neural network library originally developed by [Google Brain] and
> now by [Google DeepMind]) instead of Haiku.** 📣
>
> At the time of writing, [Flax] has a superset of the features available in
> Haiku, a [larger](https://github.com/google/flax/graphs/contributors) and
> [more active](https://github.com/google/flax/activity) development team, and
> wider adoption among users outside of Alphabet. [Flax] has
> [more extensive documentation](https://flax.readthedocs.io/),
> [examples](https://github.com/huggingface/transformers/tree/main/examples/flax),
> and an [active community](https://huggingface.co/flax-community) creating
> end-to-end examples.
>
> Haiku will remain best-effort supported; however, the project will enter
> [maintenance mode](https://en.wikipedia.org/wiki/Maintenance_mode), meaning
> that development efforts will be focussed on bug fixes and compatibility with
> new releases of JAX.
>
> New releases will be made to keep Haiku working with newer versions of Python
> and [JAX]; however, we will not be adding (or accepting PRs for) new features.
>
> We have significant usage of Haiku internally at [Google DeepMind] and
> currently plan to support Haiku in this mode indefinitely.

## What is Haiku?

> Haiku is a tool<br>
> For building neural networks<br>
> Think: "[Sonnet] for [JAX]"

Haiku is a simple neural network library for [JAX] developed by some of the
authors of [Sonnet], a neural network library for [TensorFlow].

Documentation on Haiku can be found at https://dm-haiku.readthedocs.io/.

**Disambiguation:** if you are looking for Haiku the operating system then
please see https://haiku-os.org/.

## Overview<a id="overview"></a>

[JAX] is a numerical computing library that combines NumPy, automatic
differentiation, and first-class GPU/TPU support.

Haiku is a simple neural network library for JAX that enables users to use
familiar **object-oriented programming models** while allowing full access to
JAX's pure function transformations.

Haiku provides two core tools: a module abstraction, `hk.Module`, and a simple
function transformation, `hk.transform`.

`hk.Module`s are Python objects that hold references to their own parameters,
other modules, and methods that apply functions on user inputs.

`hk.transform` turns functions that use these object-oriented, functionally
"impure" modules into pure functions that can be used with `jax.jit`,
`jax.grad`, `jax.pmap`, etc.
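
As a minimal sketch of how the two pieces fit together (the
[Quickstart](#quickstart) below walks through a fuller example):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  # hk.Linear is an hk.Module: it creates its parameters via hk.get_parameter.
  return hk.Linear(2)(x)

forward_t = hk.transform(forward)  # -> a pair of pure `init`/`apply` functions
params = forward_t.init(jax.random.PRNGKey(0), jnp.ones([1, 3]))
y = jax.jit(forward_t.apply)(params, None, jnp.ones([1, 3]))
```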

## Why Haiku?<a id="why-haiku"></a>

There are a number of neural network libraries for JAX. Why should you choose
Haiku?

### Haiku has been tested by researchers at DeepMind at scale.

- DeepMind has reproduced a number of experiments in Haiku and JAX with relative
  ease. These include large-scale results in image and language processing,
  generative models, and reinforcement learning.

### Haiku is a library, not a framework.

- Haiku is designed to make specific things simpler: managing model parameters
  and other model state.
- Haiku can be expected to compose with other libraries and work well with the
  rest of JAX.
- Haiku otherwise is designed to get out of your way - it does not define custom
  optimizers, checkpointing formats, or replication APIs.

### Haiku does not reinvent the wheel.

- Haiku builds on the programming model and APIs of Sonnet, a neural network
  library with near universal adoption at DeepMind. It preserves Sonnet's
  `Module`-based programming model for state management while retaining access
  to JAX's function transformations.
- Haiku APIs and abstractions are as close as reasonable to Sonnet. Many users
  have found Sonnet to be a productive programming model in TensorFlow; Haiku
  enables the same experience in JAX.

### Transitioning to Haiku is easy.

- By design, transitioning from TensorFlow and Sonnet to JAX and Haiku is easy.
- Outside of new features (e.g. `hk.transform`), Haiku aims to match the API of
  Sonnet 2. Modules, methods, argument names, defaults, and initialization
  schemes should match.

### Haiku makes other aspects of JAX simpler.

- Haiku offers a trivial model for working with random numbers. Within a
  transformed function, `hk.next_rng_key()` returns a unique rng key.
- These unique keys are deterministically derived from an initial random key
  passed into the top-level transformed function, and are thus safe to use with
  JAX program transformations; a short example follows below.
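
For example, calling the same transformed function twice with the same top-level
key yields identical samples. A minimal sketch (assuming the usual `import jax`,
`import jax.numpy as jnp` and `import haiku as hk`):

```python
def add_noise(x):
  # Each call to hk.next_rng_key() returns a fresh key, derived from the key
  # passed to `apply`.
  key = hk.next_rng_key()
  return x + jax.random.normal(key, x.shape)

add_noise_t = hk.transform(add_noise)
params = add_noise_t.init(jax.random.PRNGKey(0), jnp.zeros([3]))

# Same top-level key -> same derived keys -> identical outputs.
y1 = add_noise_t.apply(params, jax.random.PRNGKey(1), jnp.zeros([3]))
y2 = add_noise_t.apply(params, jax.random.PRNGKey(1), jnp.zeros([3]))
```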

## Quickstart<a id="quickstart"></a>

Let's take a look at an example neural network, loss function, and training
loop. (For more examples, see our
[examples directory](https://github.com/deepmind/dm-haiku/tree/main/examples/).
The
[MNIST example](https://github.com/deepmind/dm-haiku/tree/main/examples/mnist.py)
is a good place to start.)

```python
import haiku as hk
import jax
import jax.numpy as jnp

def softmax_cross_entropy(logits, labels):
  one_hot = jax.nn.one_hot(labels, logits.shape[-1])
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)

def loss_fn(images, labels):
  mlp = hk.Sequential([
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  logits = mlp(images)
  return jnp.mean(softmax_cross_entropy(logits, labels))

loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)

rng = jax.random.PRNGKey(42)
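# `input_dataset` below is assumed to be an iterable of (images, labels) batches.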
dummy_images, dummy_labels = next(input_dataset)
params = loss_fn_t.init(rng, dummy_images, dummy_labels)

def update_rule(param, update):
  return param - 0.01 * update

for images, labels in input_dataset:
  grads = jax.grad(loss_fn_t.apply)(params, images, labels)
  params = jax.tree_util.tree_map(update_rule, params, grads)
```

The core of Haiku is `hk.transform`. The `transform` function allows you to
write neural network functions that rely on parameters (here the weights of the
`Linear` layers) without requiring you to explicitly write the boilerplate
for initialising those parameters. `transform` does this by transforming the
function into a pair of functions that are _pure_ (as required by JAX) `init`
and `apply`.

### `init`

The `init` function, with signature `params = init(rng, ...)` (where `...` are
the arguments to the untransformed function), allows you to **collect** the
initial value of any parameters in the network. Haiku does this by running your
function, keeping track of any parameters requested through `hk.get_parameter`
(called by e.g. `hk.Linear`) and returning them to you.

The `params` object returned is a nested data structure of all the
parameters in your network, designed for you to inspect and manipulate. 
Concretely, it is a mapping of module name to module parameters, where a module
parameter is a mapping of parameter name to parameter value. For example:

```
{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
            'w': ndarray(..., shape=(28, 300), dtype=float32)},
 'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
              'w': ndarray(..., shape=(300, 100), dtype=float32)},
 'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
              'w': ndarray(..., shape=(100, 10), dtype=float32)}}
```
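
Because `params` is just a nested structure of arrays, you can inspect and
manipulate it with ordinary JAX tree utilities. For example, to get a quick
summary of every parameter's shape:

```python
# Replace each parameter array with its shape.
param_shapes = jax.tree_util.tree_map(lambda p: p.shape, params)
print(param_shapes)
```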

### `apply`

The `apply` function, with signature `result = apply(params, rng, ...)`, allows
you to **inject** parameter values into your function. Whenever
`hk.get_parameter` is called, the value returned will come from the `params` you
provide as input to `apply`:

```python
loss = loss_fn_t.apply(params, rng, images, labels)
```

Note that since the actual computation performed by our loss function doesn't
rely on random numbers, passing in a random number generator is unnecessary, so
we could also pass in `None` for the `rng` argument. (Note that if your
computation _does_ use random numbers, passing in `None` for `rng` will cause
an error to be raised.) In our example above, we instead ask Haiku to remove
the `rng` argument from `apply` entirely with:

```python
loss_fn_t = hk.without_apply_rng(loss_fn_t)
```

Since `apply` is a pure function we can pass it to `jax.grad` (or any of JAX's
other transforms):

```python
grads = jax.grad(loss_fn_t.apply)(params, images, labels)
```

### Training

The training loop in this example is very simple. One detail to note is the use
of `jax.tree_util.tree_map` to apply the `update_rule` function across all matching
entries in `params` and `grads`. The result has the same structure as the
previous `params` and can again be used with `apply`.
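
In practice, many users swap the hand-written `update_rule` for an optimizer
from a separate library such as [Optax](https://github.com/deepmind/optax).
A minimal sketch (assuming `optax` is installed, and reusing `loss_fn_t`,
`params` and `input_dataset` from above):

```python
import optax

optimizer = optax.sgd(learning_rate=0.01)
opt_state = optimizer.init(params)

for images, labels in input_dataset:
  grads = jax.grad(loss_fn_t.apply)(params, images, labels)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
```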


## Installation<a id="installation"></a>

Haiku is written in pure Python, but depends on C++ code via JAX.

Because JAX installation is different depending on your CUDA version, Haiku does
not list JAX as a dependency in `requirements.txt`.

First, follow [these instructions](https://github.com/google/jax#installation)
to install JAX with the relevant accelerator support.

Then, install Haiku using pip:

```bash
$ pip install git+https://github.com/deepmind/dm-haiku
```

Alternatively, you can install via PyPI:

```bash
$ pip install -U dm-haiku
```
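
To check that both JAX and Haiku import cleanly (and to see which versions you
have), you can run:

```bash
$ python -c "import jax, haiku; print(jax.__version__, haiku.__version__)"
```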

Our examples rely on additional libraries (e.g. [bsuite](https://github.com/deepmind/bsuite)). You can install the full set of additional requirements using pip:

```bash
$ pip install -r examples/requirements.txt
```

## User manual<a id="user-manual"></a>

### Writing your own modules

In Haiku, all modules are a subclass of `hk.Module`. You can implement any
method you like (nothing is special-cased), but typically modules implement
`__init__` and `__call__`.

Let's work through implementing a linear layer:

```python
import numpy as np

class MyLinear(hk.Module):

  def __init__(self, output_size, name=None):
    super().__init__(name=name)
    self.output_size = output_size

  def __call__(self, x):
    j, k = x.shape[-1], self.output_size
    w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
    w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
    b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
    return jnp.dot(x, w) + b
```

All modules have a name. When no `name` argument is passed to the module, its
name is inferred from the name of the Python class (for example `MyLinear`
becomes `my_linear`). Modules can have named parameters that are accessed
using `hk.get_parameter(param_name, ...)`. We use this API (rather than just
using object properties) so that we can convert your code into a pure function
using `hk.transform`.

When using modules you need to define functions and transform them into a pair
of pure functions using `hk.transform`. See our [quickstart](#quickstart) for
more details about the functions returned from `transform`:

```python
def forward_fn(x):
  model = MyLinear(10)
  return model(x)

# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk.transform(forward_fn)

x = jnp.ones([1, 1])

# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires that you pass an RNG key to `init`, since
# parameters are typically initialized randomly:
key = hk.PRNGSequence(42)
params = forward.init(next(key), x)

# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument.  Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward.apply(params, None, x)
```
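
As described above, the module name (here `my_linear`, derived from `MyLinear`)
becomes the top-level key in `params`, with one entry per `hk.get_parameter`
call. A quick way to see this (shapes assume the `x = jnp.ones([1, 1])` input
above):

```python
print(jax.tree_util.tree_map(lambda a: a.shape, params))
# e.g. {'my_linear': {'w': (1, 10), 'b': (10,)}}
```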

### Working with stochastic models

Some models may require random sampling as part of the computation.
For example, in variational autoencoders with the reparametrization trick,
a random sample from the standard normal distribution is needed. For dropout we
need a random mask to drop units from the input. The main hurdle in making this
work with JAX is the management of PRNG keys.

In Haiku we provide a simple API for maintaining a PRNG key sequence associated
with modules: `hk.next_rng_key()` (or `hk.next_rng_keys()` for multiple keys):

```python
class MyDropout(hk.Module):

  def __init__(self, rate=0.5, name=None):
    super().__init__(name=name)
    self.rate = rate

  def __call__(self, x):
    key = hk.next_rng_key()
    p = jax.random.bernoulli(key, 1.0 - self.rate, shape=x.shape)
    return x * p / (1.0 - self.rate)

forward = hk.transform(lambda x: MyDropout()(x))

key1, key2 = jax.random.split(jax.random.PRNGKey(42), 2)
params = forward.init(key1, x)
prediction = forward.apply(params, key2, x)
```

For a more complete look at working with stochastic models, please see our
[VAE example](https://github.com/deepmind/dm-haiku/tree/main/examples/vae.py).

**Note:** `hk.next_rng_key()` is not functionally pure, which means you should
avoid using it alongside JAX transformations used inside `hk.transform`.
For more information and possible workarounds, please consult the docs on
[Haiku transforms](https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html)
and available
[wrappers for JAX transforms inside Haiku networks](https://dm-haiku.readthedocs.io/en/latest/api.html#haiku-transforms).

### Working with non-trainable state

Some models may want to maintain some internal, mutable state. For example, in
batch normalization a moving average of values encountered during training is
maintained.

In Haiku we provide a simple API for maintaining mutable state that is
associated with modules: `hk.set_state` and `hk.get_state`. When using these
functions you need to transform your function using `hk.transform_with_state`
since the signature of the returned pair of functions is different:

```python
def forward(x, is_training):
  net = hk.nets.ResNet50(1000)
  return net(x, is_training)

forward = hk.transform_with_state(forward)

# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params, state = forward.init(rng, x, is_training=True)

# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits, state = forward.apply(params, state, rng, x, is_training=True)
```

If you forget to use `hk.transform_with_state` don't worry, we will print a
clear error pointing you to `hk.transform_with_state` rather than silently
dropping your state.
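
State in your own modules works just like parameters. As a minimal sketch, here
is a (hypothetical) module that keeps a running count of how many times it has
been called, using `hk.get_state` and `hk.set_state`:

```python
class CallCounter(hk.Module):

  def __call__(self, x):
    # State is created on first use, then read and updated on every call.
    count = hk.get_state("count", shape=[], dtype=jnp.int32, init=jnp.zeros)
    hk.set_state("count", count + 1)
    return x

counted = hk.transform_with_state(lambda x: CallCounter()(x))
params, state = counted.init(None, jnp.ones([2]))  # no rng needed here
out, state = counted.apply(params, state, None, jnp.ones([2]))
```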

### Distributed training with `jax.pmap`

The pure functions returned from `hk.transform` (or `hk.transform_with_state`)
are fully compatible with `jax.pmap`. For more details on SPMD programming with
`jax.pmap`,
[look here](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap).

One common use of `jax.pmap` with Haiku is for data-parallel training on many
accelerators, potentially across multiple hosts. With Haiku, that might look
like this:

```python
def loss_fn(inputs, labels):
  logits = hk.nets.MLP([8, 4, 2])(inputs)
  return jnp.mean(softmax_cross_entropy(logits, labels))

loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)

# Initialize the model on a single device.
rng = jax.random.PRNGKey(428)
sample_image, sample_label = next(input_dataset)
params = loss_fn_t.init(rng, sample_image, sample_label)

# Replicate params onto all devices.
num_devices = jax.local_device_count()
params = jax.tree_util.tree_map(lambda x: np.stack([x] * num_devices), params)

def make_superbatch():
  """Constructs a superbatch, i.e. one batch of data per device."""
  # Get N batches, then split into list-of-images and list-of-labels.
  superbatch = [next(input_dataset) for _ in range(num_devices)]
  superbatch_images, superbatch_labels = zip(*superbatch)
  # Stack the superbatches to be one array with a leading dimension, rather than
  # a python list. This is what `jax.pmap` expects as input.
  superbatch_images = np.stack(superbatch_images)
  superbatch_labels = np.stack(superbatch_labels)
  return superbatch_images, superbatch_labels

def update(params, inputs, labels, axis_name='i'):
  """Updates params based on performance on inputs and labels."""
  grads = jax.grad(loss_fn_t.apply)(params, inputs, labels)
  # Take the mean of the gradients across all data-parallel replicas.
  grads = jax.lax.pmean(grads, axis_name)
  # Update parameters using SGD or Adam or ...
  new_params = my_update_rule(params, grads)
  return new_params

# Run several training updates.
for _ in range(10):
  superbatch_images, superbatch_labels = make_superbatch()
  params = jax.pmap(update, axis_name='i')(params, superbatch_images,
                                           superbatch_labels)
```
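
After training, every leaf of the replicated `params` still carries a leading
device dimension. A common pattern (not specific to Haiku) is to keep a single
copy for checkpointing or evaluation:

```python
# Take the first replica's copy of each parameter and fetch it back to the host.
single_params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
```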

For a more complete look at distributed Haiku training, take a look at our
[ResNet-50 on ImageNet example](https://github.com/deepmind/dm-haiku/tree/main/examples/imagenet/).

## Citing Haiku<a id="citing-haiku"></a>

To cite this repository:

```
@software{haiku2020github,
  author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
  title = {{H}aiku: {S}onnet for {JAX}},
  url = {http://github.com/deepmind/dm-haiku},
  version = {0.0.10},
  year = {2020},
}
```

In this bibtex entry, the version number is intended to be from
[`haiku/__init__.py`](https://github.com/deepmind/dm-haiku/blob/main/haiku/__init__.py),
and the year corresponds to the project's open-source release.

[JAX]: https://github.com/google/jax
[Sonnet]: https://github.com/deepmind/sonnet
[TensorFlow]: https://github.com/tensorflow/tensorflow
[Flax]: https://github.com/google/flax
[Google DeepMind]: https://blog.google/technology/ai/april-ai-update/
[Google Brain]: https://research.google/teams/brain/

            
