# Mixed precision training in [JAX]
![Test status](https://github.com/deepmind/jmp/workflows/pytest/badge.svg)
![PyPI version](https://img.shields.io/pypi/v/jmp)
[**Installation**](#installation)
| [**Examples**](#examples)
| [**Policies**](#policies)
| [**Loss scaling**](#loss-scaling)
| [**Citing JMP**](#citing-jmp)
| [**References**](#references)
Mixed precision training [[0]] is a technique that mixes the use of full and
half precision floating point numbers during training to reduce the memory
bandwidth requirements and improve the computational efficiency of a given
model.
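This small sketch makes the bandwidth argument concrete. It uses NumPy dtypes as a stand-in for JAX arrays. NumPy has no `bfloat16`, so only `float16` is shown. Half precision stores each element in 2 bytes instead of 4, at the cost of a much smaller range and coarser precision:

```python
import numpy as np


def describe(dtype):
    """Return (bytes per element, largest finite value, machine epsilon)."""
    info = np.finfo(dtype)
    return np.dtype(dtype).itemsize, float(info.max), float(info.eps)


half_bytes, half_max, half_eps = describe(np.float16)
full_bytes, full_max, full_eps = describe(np.float32)
print(f"float16: {half_bytes} bytes, max {half_max}, eps {half_eps}")
print(f"float32: {full_bytes} bytes, max {full_max:.3e}, eps {full_eps:.3e}")

# The same activation tensor needs half the memory traffic in float16.
activations = np.ones((1024, 1024), dtype=np.float32)
print(activations.nbytes, activations.astype(np.float16).nbytes)
```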
This library implements support for mixed precision training in [JAX] by providing
two key abstractions: mixed precision "policies" and loss scaling. Neural
network libraries (such as [Haiku]) can integrate with `jmp` and provide
"Automatic Mixed Precision (AMP)" support, automating or simplifying the
application of policies to modules.
All code examples below assume the following:
```python
import jax
import jax.numpy as jnp
import jmp
half = jnp.float16 # On TPU this should be jnp.bfloat16.
full = jnp.float32
```
## Installation
JMP is written in pure Python, but depends on C++ code via JAX and NumPy.
Because JAX installation is different depending on your CUDA version, JMP does
not list JAX as a dependency in `requirements.txt`.
First, follow [these instructions](https://github.com/google/jax#installation)
to install JAX with the relevant accelerator support.
Then, install JMP using pip:
```bash
$ pip install git+https://github.com/deepmind/jmp
```
## Examples
You can find a
[fully worked JMP example in Haiku](https://github.com/deepmind/dm-haiku/tree/master/examples/imagenet)
which shows how to use mixed f32/f16 precision to halve training time on GPU and
mixed f32/bf16 to reduce training time on TPU by a third.
## Policies
A mixed precision policy encapsulates the dtype configuration of a mixed
precision experiment: the dtypes used to store parameters, to run computation,
and to return outputs.
```python
# Our policy specifies that we will store parameters in full precision but will
# compute and return output in half precision.
my_policy = jmp.Policy(compute_dtype=half,
                       param_dtype=full,
                       output_dtype=half)
```
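This sketch illustrates what such a policy does. It defines a hypothetical `SketchPolicy` frozen dataclass in NumPy with the same three dtypes and `cast_to_*` methods. It is not jmp's implementation. It walks nested dicts, lists and tuples. As a design choice of this sketch, it only casts floating point array leaves, so integer data such as labels is left alone:

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class SketchPolicy:
    """Hypothetical NumPy stand-in for a mixed precision policy."""
    param_dtype: type
    compute_dtype: type
    output_dtype: type

    def _cast(self, tree, dtype):
        # Recurse through dicts, lists and tuples; cast floating point arrays
        # and leave every other leaf (e.g. integer labels) unchanged.
        if isinstance(tree, dict):
            return {k: self._cast(v, dtype) for k, v in tree.items()}
        if isinstance(tree, (list, tuple)):
            return type(tree)(self._cast(v, dtype) for v in tree)
        if isinstance(tree, np.ndarray) and np.issubdtype(tree.dtype, np.floating):
            return tree.astype(dtype)
        return tree

    def cast_to_param(self, tree):
        return self._cast(tree, self.param_dtype)

    def cast_to_compute(self, tree):
        return self._cast(tree, self.compute_dtype)

    def cast_to_output(self, tree):
        return self._cast(tree, self.output_dtype)

    def with_output_dtype(self, dtype):
        # The dataclass is frozen, so return a modified copy.
        return dataclasses.replace(self, output_dtype=dtype)


policy = SketchPolicy(param_dtype=np.float32, compute_dtype=np.float16,
                      output_dtype=np.float16)
params = {"w": np.ones((3, 2), np.float32), "b": np.zeros(2, np.float32)}
labels = np.array([0, 1, 1], dtype=np.int32)

compute_params, compute_labels = policy.cast_to_compute((params, labels))
print(compute_params["w"].dtype, compute_labels.dtype)  # float16 int32

full_output_policy = policy.with_output_dtype(np.float32)
```

Because the dataclass is frozen, `with_output_dtype` has to return a modified copy. This matches the assign-the-result style used in this section.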
The policy object can be used to cast pytrees:
```python
def layer(params, x):
  params, x = my_policy.cast_to_compute((params, x))
  w, b = params["w"], params["b"]
  y = x @ w + b
  return my_policy.cast_to_output(y)

params = {"w": jnp.ones([3, 2], dtype=my_policy.param_dtype),
          "b": jnp.zeros([2], dtype=my_policy.param_dtype)}
x = jnp.ones([4, 3], dtype=full)
y = layer(params, x)
assert y.dtype == half
```
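The policy above keeps parameters in `full` precision. One reason, discussed in [[0]], is that small optimizer updates can be rounded away entirely when added to `float16` weights. The NumPy sketch below applies a hypothetical update of `1e-4` a thousand times to a weight of `1.0`:

```python
import numpy as np

update = 1e-4  # hypothetical learning rate * gradient
w16 = np.float16(1.0)
w32 = np.float32(1.0)

for _ in range(1000):
    w16 = w16 + np.float16(update)  # rounds back to 1.0 every time
    w32 = w32 + np.float32(update)  # accumulates as expected

print(float(w16), float(w32))
```

Near `1.0`, adjacent `float16` values are `2**-10` (about `9.8e-4`) apart. Each `1e-4` update therefore rounds back to `1.0`, while `float32` keeps every update.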
You can replace the output dtype of a given policy (`with_output_dtype` returns a new policy):
```python
my_policy = my_policy.with_output_dtype(full)
```
You can also define a policy via a string, which may be useful for specifying a
policy as a command-line argument or as a hyperparameter to your experiment:
```python
my_policy = jmp.get_policy("params=float32,compute=float16,output=float32")
f16_policy = jmp.get_policy("float16")  # Everything in f16.
half_policy = jmp.get_policy("half")  # Everything in half (f16 or bf16).
```
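This hypothetical `parse_policy` helper, written in NumPy, roughly illustrates how such a string can map to the three dtypes. It only understands comma-separated `key=value` pairs and a single dtype name. It does not handle `half`, which jmp resolves per platform, or `bfloat16`, which NumPy lacks. jmp's actual parser may also accept other forms:

```python
import numpy as np

# Hypothetical mapping from string keys to policy fields.
_ROLES = {"params": "param_dtype", "compute": "compute_dtype",
          "output": "output_dtype"}


def parse_policy(spec: str) -> dict:
    """Parse 'params=float32,compute=float16,output=float32' or 'float16'."""
    spec = spec.strip()
    if "=" not in spec:
        # A single dtype name applies to all three roles.
        dtype = np.dtype(spec)
        return {field: dtype for field in _ROLES.values()}
    result = {}
    for part in spec.split(","):
        key, _, value = part.partition("=")
        key = key.strip()
        if key not in _ROLES:
            raise ValueError(f"unknown policy key: {key!r}")
        result[_ROLES[key]] = np.dtype(value.strip())
    missing = set(_ROLES.values()) - set(result)
    if missing:
        raise ValueError(f"missing policy keys: {sorted(missing)}")
    return result


mixed = parse_policy("params=float32,compute=float16,output=float32")
all_f16 = parse_policy("float16")
print(mixed)
```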
## Loss scaling
When training with reduced precision, consider whether gradients will need to be
shifted into the representable range of the format that you are using. This is
particularly important when training with `float16` and less important for
`bfloat16`. See the NVIDIA mixed precision user guide [[1]] for more details.
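The NumPy snippet below shows the `float16` limits that make this matter. Values above `65504` overflow to `inf`, and magnitudes below the smallest subnormal (`2**-24`, about `6e-8`) flush to zero. The gradient values are made up for illustration:

```python
import numpy as np

info = np.finfo(np.float16)
smallest_normal = float(info.tiny)   # 2**-14, about 6.1e-5
smallest_subnormal = 2.0 ** -24      # about 6.0e-8
largest = float(info.max)            # 65504.0

# Hypothetical gradient magnitudes computed in float32.
grads = np.array([1e-4, 1e-6, 1e-8], dtype=np.float32)
grads_f16 = grads.astype(np.float16)
print(grads_f16)  # the 1e-8 entry flushes to zero

with np.errstate(over="ignore"):  # silence the expected overflow warning
    overflowed = np.array([70000.0], dtype=np.float32).astype(np.float16)
print(overflowed)  # [inf]
```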
The easiest way to shift gradients is loss scaling. Multiply the loss by `S`
before differentiating (by the chain rule this also multiplies the gradients by
`S`), then multiply the gradients by `1/S` before using them.
```python
def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: jmp.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.
  params = apply_optimizer(params, grads)
  return params

loss_scale = jmp.StaticLossScale(2 ** 15)
for _ in range(num_steps):
  params = train_step(params, loss_scale, ...)
```
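This sketch shows why scaling helps. It takes some made-up float32 gradients and simulates only their final cast to `float16`, once without scaling and once with the static scale `S = 2**15` used above. This is a simplification: a real backward pass runs many `float16` operations, not a single cast:

```python
import numpy as np

S = np.float32(2 ** 15)  # the static loss scale from the example above

# Hypothetical "true" gradients, as full precision would compute them.
true_grads = np.array([1e-8, 1e-6, 2e-3], dtype=np.float32)

# No scaling: casting to float16 flushes the smallest gradient to zero.
no_scaling = true_grads.astype(np.float16).astype(np.float32)

# Loss scaling: gradients of S * loss are S times larger (chain rule), so they
# fit in float16; unscale in float32 before the optimizer sees them.
scaled_f16 = (true_grads * S).astype(np.float16)
recovered = scaled_f16.astype(np.float32) / S

print(no_scaling)
print(recovered)
```

Without scaling, the smallest gradient is lost. With scaling, it is recovered up to `float16` rounding.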
The appropriate value for `S` depends on your model, loss, batch size and
potentially other factors. You can determine it by trial and error. As a
rule of thumb, you want the largest value of `S` that does not introduce overflow
during backprop. NVIDIA [[1]] recommend computing statistics about the gradients
of your model (in full precision) and picking `S` such that its product with the
maximum norm of your gradients is below `65,504` (the largest finite `float16`
value).
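The NVIDIA rule of thumb can be turned into a small calculation. The hypothetical helper below assumes you have already measured the largest gradient magnitude in full precision. It returns the largest power of two `S` whose product with that magnitude stays below `65,504`. The cap of `2**24` is an arbitrary assumption:

```python
FLOAT16_MAX = 65504.0  # largest finite float16 value


def pick_static_loss_scale(max_grad_magnitude: float, max_exponent: int = 24) -> float:
    """Largest power of two S with S * max_grad_magnitude < 65504 (else 1.0).

    Hypothetical helper: `max_grad_magnitude` is the largest gradient magnitude
    measured in full precision; `max_exponent` caps S at 2**max_exponent.
    """
    scale = 1.0
    while (scale * 2 <= 2.0 ** max_exponent
           and scale * 2 * max_grad_magnitude < FLOAT16_MAX):
        scale *= 2
    return scale


for magnitude in (1.5, 100.0, 1e-3):
    print(magnitude, pick_static_loss_scale(magnitude))
```

For example, a largest gradient magnitude of `1.5` gives `S = 2**15`, because `2**16 * 1.5 = 98,304` would overflow. Restricting `S` to powers of two means scaling and unscaling only change the exponent. They add no rounding error as long as nothing overflows or underflows.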
We provide a dynamic loss scale, which adjusts the loss scale periodically
during training to find the largest value for `S` that produces finite
gradients. This is more convenient and robust than picking a static loss scale,
but has a small performance impact (between 1 and 5%).
```python
def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: jmp.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.

  # You definitely want to skip non-finite updates with the dynamic loss scale,
  # but you might also want to consider skipping them when using a static loss
  # scale if you experience NaNs when training.
  skip_nonfinite_updates = isinstance(loss_scale, jmp.DynamicLossScale)

  if skip_nonfinite_updates:
    grads_finite = jmp.all_finite(grads)
    # Adjust our loss scale depending on whether gradients were finite. The
    # loss scale will be periodically increased if gradients remain finite and
    # will be decreased if not.
    loss_scale = loss_scale.adjust(grads_finite)
    # Only apply our optimizer if grads are finite; if any element of any
    # gradient is non-finite, the whole update is discarded.
    params = jmp.select_tree(grads_finite, apply_optimizer(params, grads), params)
  else:
    # With static or no loss scaling just apply our optimizer.
    params = apply_optimizer(params, grads)

  # Since our loss scale is dynamic we need to return the new value from
  # each step. All loss scales are `PyTree`s.
  return params, loss_scale

loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))
for _ in range(num_steps):
  params, loss_scale = train_step(params, loss_scale, ...)
```
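The adjust-and-skip logic above can be simulated eagerly in NumPy. The sketch below uses a hypothetical `SketchDynamicLossScale` dataclass, not jmp's class. Its `adjust` follows the behaviour described in the comments above. After `period` consecutive finite steps, the scale is multiplied by `factor`. On a non-finite step, it is divided by `factor` and the update is skipped. The `period`, `factor` and `min_loss_scale` defaults are assumptions, and `period=2` is used only to keep the example short:

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class SketchDynamicLossScale:
    """Hypothetical NumPy stand-in for a dynamic loss scale."""
    loss_scale: float
    counter: int = 0
    period: int = 2000        # assumed: finite steps before growing the scale
    factor: float = 2.0       # assumed: growth/shrink factor
    min_loss_scale: float = 1.0

    def adjust(self, grads_finite: bool) -> "SketchDynamicLossScale":
        if not grads_finite:
            # Overflow: shrink the scale (not below the minimum), reset counter.
            smaller = max(self.min_loss_scale, self.loss_scale / self.factor)
            return dataclasses.replace(self, loss_scale=smaller, counter=0)
        if self.counter == self.period - 1:
            # A full period of finite steps: try a larger scale.
            return dataclasses.replace(
                self, loss_scale=self.loss_scale * self.factor, counter=0)
        return dataclasses.replace(self, counter=self.counter + 1)


def all_finite(tree: dict) -> bool:
    return all(bool(np.all(np.isfinite(leaf))) for leaf in tree.values())


def train_step(params, scaled_grads, loss_scale, lr=0.1):
    # Unscale first, then decide whether this step's update can be trusted.
    grads = {k: g / loss_scale.loss_scale for k, g in scaled_grads.items()}
    grads_finite = all_finite(grads)
    loss_scale = loss_scale.adjust(grads_finite)
    if grads_finite:  # skip the whole update if any element is non-finite
        params = {k: params[k] - lr * grads[k] for k in params}
    return params, loss_scale


params = {"w": np.array([1.0, 2.0], dtype=np.float32)}
loss_scale = SketchDynamicLossScale(loss_scale=2.0 ** 15, period=2)
steps = [
    {"w": np.array([2.0 ** 15, 2.0 ** 15], dtype=np.float32)},  # finite
    {"w": np.array([2.0 ** 15, 2.0 ** 15], dtype=np.float32)},  # finite: scale grows
    {"w": np.array([np.inf, 1.0], dtype=np.float32)},           # overflow: skipped
]
history = []
for scaled_grads in steps:
    params, loss_scale = train_step(params, scaled_grads, loss_scale)
    history.append(loss_scale.loss_scale)

print(history)        # [32768.0, 65536.0, 32768.0]
print(params["w"])    # two updates applied, the third skipped
```

The sketch branches with a Python `if` because NumPy runs eagerly. Inside a jitted JAX `train_step`, `grads_finite` is a traced value. That is why the example above selects between the old and new parameters with `jmp.select_tree` instead of branching.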
In general, a static loss scale should offer the best speed, but we have
optimized dynamic loss scaling to make it competitive. We recommend you start
with dynamic loss scaling and move to static loss scaling if performance is an
issue.
Finally, we offer a no-op loss scale, which you can use as a drop-in replacement.
It does nothing apart from implementing the `jmp.LossScale` API:
```python
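loss_scale = jmp.NoOpLossScale()
loss = jnp.float32(2.5)
grads = {"w": jnp.ones([3], dtype=full)}
grads_finite = jmp.all_finite(grads)
assert loss is loss_scale.scale(loss)
assert grads is loss_scale.unscale(grads)
assert loss_scale is loss_scale.adjust(grads_finite)
assert loss_scale.loss_scale == 1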
```
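Every loss scale exposes the same `scale`/`unscale`/`adjust`/`loss_scale` members, so training code written against those members does not change when you swap implementations. The sketch below shows this with two hypothetical NumPy classes, `NoOpScale` and `StaticScale`, that mimic those members. They are not jmp's classes. The example uses a toy loss `sum(w * x)`, whose gradient with respect to `w` is `x`:

```python
import numpy as np


class NoOpScale:
    """Hypothetical no-op loss scale: every method is an identity."""
    loss_scale = 1.0

    def scale(self, tree):
        return tree

    def unscale(self, tree):
        return tree

    def adjust(self, grads_finite):
        return self


class StaticScale:
    """Hypothetical static loss scale over a flat dict of arrays."""

    def __init__(self, loss_scale):
        self.loss_scale = loss_scale

    def scale(self, tree):
        return {k: v * self.loss_scale for k, v in tree.items()}

    def unscale(self, tree):
        return {k: v / self.loss_scale for k, v in tree.items()}

    def adjust(self, grads_finite):
        return self


def sgd_step(params, x, loss_scale, lr=0.1):
    # Toy loss: sum(w * x), whose gradient w.r.t. w is x. The gradient of the
    # scaled loss is therefore loss_scale * x (written out analytically here).
    scaled_grads = {"w": loss_scale.loss_scale * x}
    grads = loss_scale.unscale(scaled_grads)
    return {"w": params["w"] - lr * grads["w"]}


params = {"w": np.array([1.0, 2.0])}
x = np.array([0.5, -1.0])
no_op_result = sgd_step(params, x, NoOpScale())
static_result = sgd_step(params, x, StaticScale(2.0 ** 15))
print(no_op_result["w"], static_result["w"])
```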
## Citing JMP
This repository is part of the [DeepMind JAX Ecosystem](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research).
To cite JMP, please use the [DeepMind JAX Ecosystem citation](https://github.com/deepmind/jax/blob/main/deepmind2020jax.txt).
## References
[[0]] Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory Diamos, Erich
Elsen, David Garcia, Boris Ginsburg, Michael Houston, Oleksii Kuchaiev, Ganesh
Venkatesh, Hao Wu: "Mixed Precision Training", 2017; arXiv:1710.03740
https://arxiv.org/abs/1710.03740.
[[1]] "Training With Mixed Precision :: NVIDIA Deep Learning Performance
Documentation". docs.nvidia.com, 2020,
https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/.
[0]: https://arxiv.org/abs/1710.03740
[1]: https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html
[Haiku]: https://github.com/deepmind/dm-haiku
[JAX]: https://github.com/google/jax