# mpx: Mixed Precision Training for JAX
## Installation
Before installing the library, please make sure that you have installed JAX for your hardware.
```console
pip install mixed-precision-for-JAX
```
## Documentation
For basic usage, this README should give you everything you need to know.
For deeper insights, you can read the [documentation](https://data-science-in-mechanical-engineering.github.io/mixed_precision_for_JAX/) and our [paper](https://www.arxiv.org/pdf/2507.03312).
## Introduction
This repository offers a tool for training JAX models using mixed precision, called **mpx**. It builds upon [JMP](https://github.com/google-deepmind/jmp)—another mixed precision library for JAX—but extends its capabilities.
JMP does not support arbitrary PyTrees and is in particular incompatible with models developed using [Equinox](https://docs.kidger.site/equinox/). mpx overcomes these limitations by leveraging Equinox's flexibility to work with any PyTree.
## Basics of Mixed Precision Training
This section summarizes the original Mixed Precision method from https://developer.nvidia.com/automatic-mixed-precision and https://arxiv.org/pdf/1710.03740.
Mixed Precision training involves performing most of the computations in the forward and backward passes of a neural network using 16-bit floating-point numbers.
This approach reduces GPU memory usage by roughly half compared to full precision training, allowing for larger batch sizes or the use of fewer TPUs/GPUs. Additionally, mixed precision can speed up training by decreasing memory access times and utilizing specialized half-precision tensor cores on modern hardware (if available).
One of the key factors in successfully applying Mixed Precision training is loss scaling. Due to the reduced resolution of float16, small gradients are flushed to zero, which degrades training performance. Loss scaling multiplies the loss by a factor > 1 and, as a result, also the gradients computed during backpropagation. Afterwards, the gradients are cast to float32 and divided by the same factor to recover the original gradient. A standard optimizer then uses this gradient to compute the model update. The scaling factor can be chosen automatically with a simple heuristic: if the scaled gradients exceed the range of float16 (i.e., they are inf), we reduce the scaling and skip the model update; if the scaled gradients stay within the range of float16 for a longer period, we increase the scaling.
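As a small numerical illustration (plain JAX, independent of `mpx`), a gradient that underflows to zero in float16 survives when it is scaled first and only unscaled after casting back to float32:
```python
import jax.numpy as jnp

g = jnp.float32(1e-8)                      # a small gradient value
print(g.astype(jnp.float16))               # 0.0 -> underflows in float16
scale = jnp.float32(2.0 ** 15)             # loss-scaling factor > 1
scaled = (g * scale).astype(jnp.float16)   # ~3.3e-4, representable in float16
print(scaled.astype(jnp.float32) / scale)  # ~1e-8 -> original value recovered
```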
Mixed Precision Training hence has the following steps:
1. Initialize the model and optimizer in full precision.
2. Get a batch from the dataloader.
3. Cast the batch and model to half precision (e.g., float16 or bfloat16).
4. Do the forward pass in half precision, except for critical operations.
5. Scale the loss.
6. Calculate the gradient of the scaled loss with respect to the weights.
7. Cast the gradients to float32 and divide them by the scaling factor.
8. If the gradients are infinite, decrease the scaling; otherwise, increase it periodically.
9. If the gradients are finite, do the optimizer update; continue with step 2.
`mpx` provides important functions for steps 3--9. However, it does not provide a Keras/PyTorch Lightning/Kauldron-like functionality where you just pass the model, loss, and optimizer and call run. This is intentional: it preserves the low-level approach of JAX and lets users write their training pipeline however they prefer.
## Main Features
`mpx` provides a comprehensive set of tools for mixed precision training in JAX.
The main goal was to keep the library as flexible and as close to `equinox` as possible.
As a result, to update a training pipeline to work with mixed precision, one just has to:
- Update the gradient calculations from `eqx.filter_grad/filter_value_and_grad` to `mpx.filter_grad/filter_value_and_grad`.
- Do the `optax` optimizer call via `mpx.optimizer_update`.
- Force critical operations like sum, mean and softmax to full precision using `mpx.force_full_precision`.
Here are the key components:
### Data Type Management
- `set_half_precision_datatype(dtype)`: Configure whether to use `float16` or `bfloat16` for half precision training
- `half_precision_datatype()`: Get the currently configured half precision data type
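For example, switching the half precision type to `bfloat16` could look like this (a minimal sketch, assuming the two functions are exposed at the top level of `mpx` as listed above):
```python
import jax.numpy as jnp
import mpx

# use bfloat16 instead of float16 for all half-precision casts
mpx.set_half_precision_datatype(jnp.bfloat16)
print(mpx.half_precision_datatype())  # -> bfloat16
```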
### Casting Functions
- `cast_to_half_precision(x: PyTree)`: Cast all JAX arrays in a PyTree to the configured half precision type
- `cast_to_full_precision(x: PyTree)`: Cast all JAX arrays in a PyTree to `float32`
- `cast_to_float16(x: PyTree)`: Cast all JAX arrays in a PyTree to `float16`
- `cast_to_bfloat16(x: PyTree)`: Cast all JAX arrays in a PyTree to `bfloat16`
- `cast_to_float32(x: PyTree)`: Cast all JAX arrays in a PyTree to `float32`
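These helpers operate on whole PyTrees, e.g., a batch dictionary or an Equinox model. A minimal sketch (the batch contents are illustrative; how non-float leaves are treated is up to the library):
```python
import jax.numpy as jnp
import mpx

batch = {"image": jnp.ones((32, 3, 32, 32)), "label_smoothing": jnp.float32(0.1)}
half_batch = mpx.cast_to_half_precision(batch)       # float16 or bfloat16, as configured
full_batch = mpx.cast_to_full_precision(half_batch)  # back to float32
print(half_batch["image"].dtype, full_batch["image"].dtype)
```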
### Precision Control
- `force_full_precision`: A decorator that ensures a function performs all calculations in `float32`. This is essential for maintaining numerical stability in operations like mean, sum, and softmax. Currently, this has to be done by hand, i.e., **`mpx` does not identify critical operations and force them to full precision automatically, unlike PyTorch's AMP**. This unfortunately also means that neural network functions provided by Equinox or Flax that include such critical operations (e.g., `equinox.nn.MultiheadAttention`) must be rewritten by hand (please refer to our example and the short sketch below). In a future release, we plan to provide a library of typical neural network functions, like attention, that are ready for mixed precision.
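The usage pattern mirrors the example further below: the wrapped function runs in `float32`, and the second argument is the dtype used for the result (a minimal sketch, assuming these semantics):
```python
import jax
import jax.numpy as jnp
import mpx

x = jnp.linspace(-1.0, 1.0, 8).astype(jnp.float16)  # half-precision activations

# reductions and softmax are computed in full precision internally
mean_x = mpx.force_full_precision(jnp.mean, x.dtype)(x)
soft_x = mpx.force_full_precision(jax.nn.softmax, x.dtype)(x, axis=-1)
```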
### Loss Scaling
- `DynamicLossScaling`: A class that manages dynamic loss scaling to prevent underflow in half precision training. It is syntactically equivalent to `jmp.DynamicLossScaling`, but it can scale arbitrary PyTrees.
  - `scale(x)`: Scale a value by the current loss scaling factor
  - `unscale(x)`: Remove the loss scaling factor from a value
  - `adjust(grads_finite)`: Update the loss scaling factor based on gradient stability

The following functions are mainly used internally; however, they might be interesting for non-standard implementations:
- `scaled(func, scaling)`: Decorator that applies loss scaling to a function's output
- `all_finite(tree)`: Check if all values in a PyTree are finite (not NaN or Inf)
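As an illustration of how these primitives map onto steps 5--8 above, here is a minimal sketch with a toy scalar loss; normally `mpx.filter_grad`/`mpx.filter_value_and_grad` perform these steps for you, and only the interfaces documented in this README are assumed:
```python
import jax
import jax.numpy as jnp
import mpx

scaling = mpx.DynamicLossScaling(loss_scaling=mpx.FLOAT16_MAX,
                                 min_loss_scaling=jnp.ones((), dtype=jnp.float32),
                                 period=2000)

def loss_fn(w, x):
    return jnp.mean((w * x) ** 2)

w = jnp.float16(0.1)
x = jnp.ones((8,), dtype=jnp.float16)

grads = jax.grad(lambda w: scaling.scale(loss_fn(w, x)))(w)  # gradient of the *scaled* loss
grads_finite = mpx.all_finite(grads)                         # did anything overflow to inf?
scaling = scaling.adjust(grads_finite)                       # grow/shrink the scaling factor
grads = scaling.unscale(mpx.cast_to_float32(grads))          # recover the true gradient in float32
```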
### Gradient Computation
`mpx` provides function decorators for gradient calculations that summarize steps 3--8 in one function call. They have the same meaning and syntax as the corresponding decorators of `equinox`. This means that, for an existing training pipeline, one can replace calls to `equinox.filter_grad/filter_value_and_grad` with `mpx.filter_grad/filter_value_and_grad`.
- `filter_grad(func, scaling: loss_scaling.DynamicLossScaling, has_aux=False, use_mixed_precision=True)`: Transformation that computes the gradient of `func` with respect to its first argument using mixed precision with loss scaling, similar to `equinox.filter_grad`. The transformed function works as follows:
  1. If `use_mixed_precision` is True:
     - Casts all input arguments to half precision (float16/bfloat16)
     - Scales the function's output by `scaling`
  2. Computes gradients using `equinox.filter_grad`
  3. If `use_mixed_precision` is True:
     - Casts gradients back to full precision (float32)
     - Checks whether the gradients are finite
     - Updates `scaling` based on whether the gradients are finite
     - Unscales the gradients by dividing by `scaling`
  4. Returns a tuple containing:
     - The updated `scaling` object
     - A boolean indicating whether the gradients are finite (needed for the optimizer step, see below)
     - The computed gradients
     - Auxiliary values (if `has_aux=True`)
- `filter_value_and_grad(func, scaling)`: Decorator that works like `filter_grad`, except that it also returns the value.
The gradient transformations might return gradients that are infinite. In this case, the pipeline needs to skip the model update. For this, `mpx` provides the following function:
- `optimizer_update(model: PyTree, optimizer: optax.GradientTransformation, optimizer_state: PyTree, grads: PyTree, grads_finite: Bool)`: Apply optimizer updates only when gradients are finite. Works with arbitrary `optax` optimizers.
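To make the call pattern concrete before the full example below, here is a minimal sketch with a toy Equinox model and an illustrative `loss_fn`; the return order with `has_aux=False` is assumed to mirror the `has_aux=True` case shown in the example:
```python
import equinox as eqx
import jax
import jax.numpy as jnp
import mpx
import optax

model = eqx.nn.Linear(4, 1, key=jax.random.PRNGKey(0))
optimizer = optax.sgd(1e-3)
optimizer_state = optimizer.init(eqx.filter(model, eqx.is_array))
loss_scaling = mpx.DynamicLossScaling(loss_scaling=mpx.FLOAT16_MAX,
                                      min_loss_scaling=jnp.ones((), dtype=jnp.float32),
                                      period=2000)

def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    # the mean is a critical reduction, hence forced to full precision
    return mpx.force_full_precision(jnp.mean, jnp.float32)((pred - y) ** 2)

x, y = jnp.ones((8, 4)), jnp.zeros((8, 1))

loss_value, loss_scaling, grads_finite, grads = mpx.filter_value_and_grad(
    loss_fn, scaling=loss_scaling)(model, x, y)
model, optimizer_state = mpx.optimizer_update(
    model, optimizer, optimizer_state, grads, grads_finite)
```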
## Example
The following is a small example that trains a vision transformer on CIFAR-100 and presents all the important features of `mpx`. For details, please see `examples/train_vit.py`.
This example does not go into the details of the neural network itself, but focuses on the `mpx`-relevant parts.
The example was tested on an RTX 4070: without mixed precision, training with a batch size of 256 crashes due to insufficient GPU memory. With mixed precision, the training runs, demonstrating that mixed precision training via `mpx` effectively reduces GPU memory usage. The training speed itself does not change dramatically, as the RTX 4070 does not have higher throughput for half precision operations.
### Installation and Execution of the Example
First install JAX for your hardware.
Then, install all dependencies via
```bash
pip install -r examples/requirements.txt
```
Then you can run the example with the following command. ATTENTION: The script downloads CIFAR-100.
```bash
python -m examples.train_vit
```
### Explanation
The loss scaling has to be initialized during the instantiation of the datasets, models, etc. Typically, the initial value is set to the maximum value of `float16`.
```python
loss_scaling = mpx.DynamicLossScaling(loss_scaling=mpx.FLOAT16_MAX,
                                      min_loss_scaling=jnp.ones((), dtype=jnp.float32),
                                      period=2000)
```
The `loss_scaling` object then must be passed to the training pipeline.
The most important part is the training step. `mpx` makes transforming your training step to mixed precision very easy. As you can see, the only changes you have to make are to replace the call to `eqx.filter_value_and_grad` with `mpx.filter_value_and_grad` and to perform the optimizer update via `mpx.optimizer_update`. Also, do not forget to return `loss_scaling` from your step function, because `loss_scaling` is updated.
```python
@eqx.filter_jit
def make_step(model: eqx.Module,
              optimizer: any,
              optimizer_state: PyTree,
              batch: dict,
              batch_sharding: jax.sharding.NamedSharding,
              replicated_sharding: jax.sharding.NamedSharding,
              loss_scaling: mpx.DynamicLossScaling,
              train_mixed_precision: bool,
              weight_regularization: Float,
              key: PRNGKeyArray
              ) -> tuple[eqx.Module, PyTree, mpx.DynamicLossScaling, Float]:
    batch = eqx.filter_shard(batch, batch_sharding)
    model = eqx.filter_shard(model, replicated_sharding)
    optimizer_state = eqx.filter_shard(optimizer_state, replicated_sharding)

    if train_mixed_precision:
        # this is the critical part
        (loss_value, _), loss_scaling, grads_finite, grads = mpx.filter_value_and_grad(
            batched_loss_acc_wrapper, scaling=loss_scaling, has_aux=True)(
                model, batch, batch_sharding, replicated_sharding, key, weight_regularization)
        model, optimizer_state = mpx.optimizer_update(model, optimizer, optimizer_state,
                                                      grads, grads_finite)
    else:
        (loss_value, _), grads = eqx.filter_value_and_grad(batched_loss_acc_wrapper, has_aux=True)(
            model, batch, batch_sharding, replicated_sharding, key)
        # optimizer step
        updates, optimizer_state = optimizer.update(
            grads, optimizer_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)

    model = eqx.filter_shard(model, replicated_sharding)
    optimizer_state = eqx.filter_shard(optimizer_state, replicated_sharding)
    loss_scaling = eqx.filter_shard(loss_scaling, replicated_sharding)

    # return loss_scaling as it is changed
    return model, optimizer_state, loss_scaling, loss_value
```
Through the transformation via `mpx.filter_value_and_grad`, we can write our loss function as we normally would when using JAX/Equinox.
Only critical operations, like the mean, need to be forced to full precision.
```python
@eqx.filter_jit
def batched_loss_acc_wrapper(model, batch, batch_sharding, replicated_sharding, key, weight_regularization=0.0):
    batch = eqx.filter_shard(batch, batch_sharding)
    model = eqx.filter_shard(model, replicated_sharding)

    pred = predict_batch(model, batch, False, key)

    target = batch["target"]
    losses = jax.vmap(model.loss)(pred, target)
    acc = jax.vmap(model.acc)(pred, target)

    # especially for high batch sizes the mean calculation can overflow, hence force it to full precision.
    loss = mpx.force_full_precision(jnp.mean, losses.dtype)(losses)
    acc = mpx.force_full_precision(jnp.mean, losses.dtype)(acc)

    # weight regularization can help in mixed precision training.
    # It keeps the weights small and prevents overflow during matrix multiplication.
    params, _ = eqx.partition(model, eqx.is_array)
    params = jax.tree_util.tree_leaves(params)
    params = jax.tree_util.tree_map(lambda x: x.flatten(), params)
    params = jnp.concatenate(params).flatten()

    loss = loss + weight_regularization * mpx.force_full_precision(jnp.mean, jnp.float32)(jnp.abs(params))

    return loss, acc
```
The same holds for the neural network. Here, only critical operations like layer norm and softmax must be forced to full precision.
This means that, as long as your layers do not contain such operations, you can rely on the standard `equinox.nn` implementations. For other layers, the only solution so far is to reimplement them and force critical operations to full precision.
```python
class MultiHeadAttentionBlock(eqx.Module):
    dense_qs: DenseLayer
    dense_ks: DenseLayer
    dense_vs: DenseLayer

    dense_o: DenseLayer

    num_heads: int

    dropout: eqx.nn.Dropout
    layer_norm: eqx.nn.LayerNorm

    ...

    @staticmethod
    def attention(q: Array,
                  k: Array,
                  v: Array,
                  dropout: eqx.nn.Dropout,
                  key: PRNGKeyArray,
                  inference: bool) -> Array:
        attention_scores = q @ k.T
        attention_scores /= jnp.sqrt(q.shape[-1])

        # softmax is critical
        attention_scores = mpx.force_full_precision(jax.nn.softmax, attention_scores.dtype)(attention_scores, axis=-1)

        attention_scores = dropout(attention_scores, inference=inference, key=key)
        return attention_scores @ v

    def __call__(self, inputs: Array, inference: bool, key: PRNGKeyArray) -> Array:
        # also force layernorm to full precision.
        inputs_after_layernorm = jax.vmap(mpx.force_full_precision(self.layer_norm, inputs.dtype))(inputs)
        qs = jax.vmap(self.dense_qs)(inputs_after_layernorm)
        ks = jax.vmap(self.dense_ks)(inputs_after_layernorm)
        vs = jax.vmap(self.dense_vs)(inputs_after_layernorm)

        qs = es.jax_einshape("n(hf)->hnf", qs, h=self.num_heads)
        ks = es.jax_einshape("n(hf)->hnf", ks, h=self.num_heads)
        vs = es.jax_einshape("n(hf)->hnf", vs, h=self.num_heads)

        keys = jax.random.split(key, self.num_heads)

        outputs = jax.vmap(self.attention, in_axes=(0, 0, 0, None, 0, None))(
            qs,
            ks,
            vs,
            self.dropout,
            keys,
            inference)

        # reshape outputs (concatenate heads)
        outputs = es.jax_einshape("hnf->n(hf)", outputs)

        key, key2 = jax.random.split(key)
        outputs = jax.vmap(self.dense_o)(outputs)
        outputs = self.dropout(outputs, inference=inference, key=key2)
        outputs += inputs

        return outputs
```
## Citation
To cite this repository, please cite our [paper](https://www.arxiv.org/pdf/2507.03312):
```
@ARTICLE{2025arXiv250703312G,
    author = {{Gr{\"a}fe}, Alexander and {Trimpe}, Sebastian},
    title = "{MPX: Mixed Precision Training for JAX}",
    journal = {arXiv e-prints},
    year = 2025,
    doi = {10.48550/arXiv.2507.03312},
}
```
## Acknowledgements
We want to thank Patrick Kidger for providing Equinox and Google DeepMind for providing JMP, which was the basis for this implementation.
The authors gratefully acknowledge the computing time provided to them at the NHR Center NHR4CES at RWTH Aachen University (project number p0021919). This is funded by the Federal Ministry of Education and Research, and the state governments participating on the basis of the resolutions of the GWK for national high performance computing at universities (www.nhr-verein.de/unsere-partner).