flashbax

- **Name**: flashbax
- **Version**: 0.1.0
- **Summary**: Flashbax is an experience replay library oriented around JAX. Tailored to integrate seamlessly with JAX's Just-In-Time (JIT) compilation.
- **Homepage**: https://github.com/instadeepai/flashbax
- **Upload time**: 2024-02-06 10:33:04
- **Requires Python**: >=3.9
- **Keywords**: jax, memory, python, reinforcement-learning

<p align="center">
    <a href="./docs/img/logo.png#gh-light-mode-only">
        <img src="./docs/imgs/logo.png#gh-light-mode-only" alt="Flashbax Logo" width="70%"/>
    </a>
    <a href="./docs/img/logo_dm.png#gh-dark-mode-only">
        <img src="./docs/imgs/logo_dm.png#gh-dark-mode-only" alt="Flashbax Logo" width="70%"/>
    </a>
</p>

[![Python Versions](https://img.shields.io/pypi/pyversions/flashbax.svg?style=flat-square)](https://www.python.org/doc/versions/)
[![PyPI Version](https://badge.fury.io/py/flashbax.svg)](https://badge.fury.io/py/flashbax)
[![Tests](https://github.com/instadeepai/flashbax/actions/workflows/tests_linters.yml/badge.svg)](https://github.com/instadeepai/flashbax/actions/workflows/tests_linters.yml)
[![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![MyPy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/)
[![License](https://img.shields.io/badge/License-Apache%202.0-orange.svg)](https://opensource.org/licenses/Apache-2.0)

<div align="center">
    <h3>
      <a href="#overview-">Overview</a> |
      <a href="#features-%EF%B8%8F">Features</a> |
      <a href="#setup-">Setup</a> |
      <a href="#quickstart-">Quick Start</a> |
      <a href="#examples-">Examples</a> |
      <a href="#important-considerations-%EF%B8%8F">Important Considerations</a> |
      <a href="#benchmarks-">Benchmarks</a> |
      <a href="#contributing-">Contributing</a> |
      <a href="#see-also-">See Also</a> |
      <a href="#citing">Citing</a> |
      <a href="#acknowledgements-">Acknowledgments</a>
    </h3>
</div>


---

# ⚡ High Speed Buffers in JAX ⚡

## Overview 🔍

Flashbax is a library designed to streamline the use of experience replay buffers within the context of reinforcement learning (RL). Tailored specifically for compatibility with the JAX paradigm, Flashbax allows these buffers to be easily utilised within fully compiled functions and training loops.

Flashbax provides implementations of various buffer types, such as the Flat Buffer, the Trajectory Buffer, and Prioritised variants of both. Whether for academic research, industrial applications, or personal projects, Flashbax offers a simple and flexible framework for RL experience replay handling.

## Features 🛠️

🚀 **Efficient Buffer Variants**: All Flashbax buffers are built as specialised variants of the trajectory buffer, optimising memory usage and functionality across various types of buffers.

🗄️ **Flat Buffer**: The Flat Buffer, akin to the transition buffer used in algorithms like DQN, is a core component. It stores sequences of length 2 (i.e. $s_t$, $s_{t+1}$) with a period of 1, so that every consecutive pair of timesteps can be sampled as a transition.

🧺 **Item Buffer**: The Item Buffer is a simple buffer that stores individual items. It is useful for storing items that are independent of one another, such as (observation, action, reward, discount, next_observation) tuples, or entire episodes.

🛤️ **Trajectory Buffer**: The Trajectory Buffer facilitates the sampling of multi-step trajectories, catering to algorithms utilising recurrent networks like R2D2 (Kapturowski et al., [2018](https://www.deepmind.com/publications/recurrent-experience-replay-in-distributed-reinforcement-learning)).

๐Ÿ… **Prioritised Buffers**: Both Flat and Trajectory Buffers can be prioritised, enabling sampling based on user-defined priorities. The prioritisation mechanism aligns with the principles outlined in the PER paper (Schaul et al, [2016](https://arxiv.org/abs/1511.05952)).

🚶 **Trajectory/Flat Queue**: A queue data structure is provided for cases where data should be sampled in FIFO order. The queue can be used for on-policy algorithms with specific use cases.

## Setup 🎬

To integrate Flashbax into your project, follow these steps:

1. **Installation**: Begin by installing Flashbax using `pip`:
```bash
pip install flashbax
```

2. **Selecting Buffers**: Choose from a range of buffer options, including Flat Buffer, Trajectory Buffer, and Prioritised variants.
```python
import flashbax as fbx

buffer = fbx.make_trajectory_buffer(...)
# OR
buffer = fbx.make_prioritised_trajectory_buffer(...)
# OR
buffer = fbx.make_flat_buffer(...)
# OR
buffer = fbx.make_prioritised_flat_buffer(...)
# OR
buffer = fbx.make_item_buffer(...)
# OR
buffer = fbx.make_trajectory_queue(...)

# Initialise
state = buffer.init(example_timestep)
# Add Data
state = buffer.add(state, example_data)
# Sample Data
batch = buffer.sample(state, rng_key)
```

## Quickstart 🏁

Below we provide a minimal code example for using the flat buffer. In this example, we show how each of the pure functions defining the flat buffer may be used. Each of these pure functions is compatible with `jax.pmap` and `jax.jit`, but for simplicity they are not used in the example below.

```python
import jax
import jax.numpy as jnp
import flashbax as fbx

# Instantiate the flat buffer NamedTuple using `make_flat_buffer` with a simple configuration.
# The returned `buffer` is simply a container for the pure functions needed for using a flat buffer.
buffer = fbx.make_flat_buffer(max_length=32, min_length=2, sample_batch_size=1)

# Initialise the buffer's state.
fake_timestep = {"obs": jnp.array([0, 0]), "reward": jnp.array(0.0)}
state = buffer.init(fake_timestep)

# Now we add data to the buffer.
state = buffer.add(state, {"obs": jnp.array([1, 2]), "reward": jnp.array(3.0)})
print(buffer.can_sample(state))  # False because min_length not reached yet.

state = buffer.add(state, {"obs": jnp.array([4, 5]), "reward": jnp.array(6.0)})
print(buffer.can_sample(state))  # Still False because we need 2 *transitions* (i.e. 3 timesteps).

state = buffer.add(state, {"obs": jnp.array([7, 8]), "reward": jnp.array(9.0)})
print(buffer.can_sample(state))  # True! We have 2 transitions (3 timesteps).

# Get a transition from the buffer.
rng_key = jax.random.PRNGKey(0)  # Source of randomness.
batch = buffer.sample(state, rng_key)  # Sample

# We have a transition! Prints: obs = [[4 5]], obs' = [[7 8]]
print(
    f"obs = {batch.experience.first['obs']}, obs' = {batch.experience.second['obs']}"
)
```
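Since `buffer.add` and `buffer.sample` are pure functions, they can also be jitted directly. Below is a minimal sketch of doing so, reusing the `buffer` and `state` from the example above; donating the state argument is an optional choice (discussed further under In-place Updating of Buffer State below), not a Flashbax requirement.

```python
# A minimal sketch: jit the pure add/sample functions from the Quickstart above.
# Donating the state argument (optional) allows JAX to update the buffer in place.
add = jax.jit(buffer.add, donate_argnums=(0,))
sample = jax.jit(buffer.sample)

state = add(state, {"obs": jnp.array([10, 11]), "reward": jnp.array(12.0)})
batch = sample(state, jax.random.PRNGKey(1))
```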

## Examples 🧑‍💻

We provide the following Colab notebooks as more advanced tutorials on how to use each of the Flashbax buffers, along with full usage examples:

| Colab Notebook | Description |
|----------------|-------------|
| [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/quickstart_flat_buffer.ipynb) | Flat Buffer Quickstart |
| [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/quickstart_trajectory_buffer.ipynb) | Trajectory Buffer Quickstart |
| [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/quickstart_prioritised_flat_buffer.ipynb) | Prioritised Flat Buffer Quickstart |
| [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/anakin_dqn_example.ipynb) | Anakin DQN |
| [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/anakin_prioritised_dqn_example.ipynb) | Anakin Prioritised DQN |
| [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/anakin_ppo_example.ipynb) | Anakin PPO |
| [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/gym_dqn_example.ipynb) | DQN with Vectorised Gym Environments |

- 👾 [Anakin](https://arxiv.org/abs/2104.06272) - a JAX-based architecture for JIT-compiling the training of RL agents end-to-end.
- 🎮 [DQN](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_jax.py) - implementation adapted from CleanRL's DQN JAX example.
- 🦎 [Jumanji](https://github.com/instadeepai/jumanji/) - we utilise Jumanji's JAX-based environments, like Snake, for our fully jitted examples.

## Vault 💾
Vault is an efficient mechanism for saving Flashbax buffers to persistent data storage, e.g. for use in offline reinforcement learning. Consider a Flashbax buffer which has experience data of dimensionality $(B, T, *E)$, where $B$ is a batch dimension (for the sake of recording independent trajectories synchronously), $T$ is a temporal/sequential dimension, and $*E$ indicates the one or more dimensions of the experience data itself. Since large quantities of data may be generated for a given environment, Vault extends the $T$ dimension to a virtually unconstrained degree by reading and writing slices of buffers along this temporal axis. In doing so, gigantic buffer stores can reside on disk, from which sub-buffers can be loaded into RAM/VRAM for efficient offline training. Vault has been tested with the item, flat, and trajectory buffers.

For more information, see the demonstrative notebook: [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/vault_demonstration.ipynb)
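As a rough sketch of what this looks like in practice, the snippet below saves and restores a buffer state; the constructor arguments and method names here are assumptions drawn from the demonstration notebook, so treat the notebook as authoritative.

```python
# Rough sketch of Vault usage; argument and method names are assumptions based on
# the demonstration notebook linked above.
from flashbax.vault import Vault

vault = Vault(
    vault_name="my_experiment",
    experience_structure=state.experience,  # structure of the buffer's experience pytree
)
vault.write(state)             # persist any fresh (previously unseen) timesteps to disk
restored_state = vault.read()  # later: load the stored experience back for offline training
```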


## Important Considerations ⚠️

When working with Flashbax buffers, it's crucial to be mindful of certain considerations to ensure the proper functionality of your RL agent.

### Sequential Data Addition
Flashbax uses a trajectory buffer as the foundation for all buffer types. This means that data must be added sequentially: specifically, for the flat buffer, each added timestep must be followed by its consecutive timestep. In most scenarios, this requirement is naturally satisfied and doesn't demand extensive consideration. However, it's essential to be aware of this constraint, especially when adding batches of data that are completely independent of each other. Failing to maintain the sequence relationship between timesteps can lead to algorithmic issues. The user is expected to handle the transition from the final timestep of episode `n` to the first timestep of episode `n+1` within the same batch. For example, we utilise auto-reset wrappers to automatically reset the environment upon a terminal timestep, and we store discount values (1 for a non-terminal state, 0 for a terminal state) to mask the value function and the discounting of rewards accordingly.
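For instance, a 1-step TD target can be masked with the stored discount so that no bootstrapping occurs across an episode boundary. The sketch below is purely illustrative and not part of the Flashbax API; all values are made-up placeholders.

```python
import jax.numpy as jnp

# Illustrative sketch (not Flashbax API): mask bootstrapping across episode boundaries
# with the stored discount (1.0 for non-terminal, 0.0 for terminal). Values are placeholders.
gamma = 0.99
reward = jnp.array([1.0, 0.5])            # rewards at s_t
discount = jnp.array([1.0, 0.0])          # the second entry ends an episode
next_state_value = jnp.array([2.0, 7.0])  # e.g. max_a Q(s_{t+1}, a)

td_target = reward + gamma * discount * next_state_value
# -> [2.98, 0.5]; the terminal step contributes no bootstrapped value.
```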

### Effective Buffer Size
When adding batches of data, the buffer is created in a block-like structure, so the effective buffer size depends on the size of the add batch dimension. The trajectory buffer allows a user to specify the add batch dimension and the maximum length of the time axis, creating a block structure of `(batch, time)` in which the maximum number of timesteps that can be held in storage is `batch * time`. For ease of use, we also provide a max size argument that lets a user set their total desired number of timesteps, from which we calculate the maximum length of the time axis given the provided add batch dimension. It is important to note that when using the max size argument, the maximum length of the time axis will be `max_size // add_batch_size`, which rounds down and thereby reduces the effective buffer size. One might therefore believe they are increasing the buffer size by a certain amount when, in actuality, there is no increase (see the sketch below). To avoid this, we recommend one of two things: set the max length of the time axis argument explicitly, or increase the max size argument in multiples of the add batch size.
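As a quick illustration of this rounding pitfall (the numbers below are arbitrary):

```python
# Arbitrary illustrative numbers showing how the max size argument rounds down.
add_batch_size = 3
max_size = 1000  # desired total number of timesteps

max_length_time_axis = max_size // add_batch_size       # 333 (rounded down)
effective_size = max_length_time_axis * add_batch_size  # 999, not 1000

# Bumping max_size to 1001 changes nothing (1001 // 3 is still 333); only increasing it
# in multiples of the add batch size, e.g. to 1002, actually grows the buffer (334 * 3).
```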

### Handling Episode Truncation
Another critical aspect is episode truncation. When truncating episodes and adding data to the buffer, it's vital to set the done flag or 'discount' value appropriately. Neglecting to do so can introduce subtle issues into your algorithm's implementation and behaviour. As stated previously, the algorithm is expected to handle these cases appropriately. Handling truncation can be difficult with the flat buffer or trajectory buffer, as the algorithm must deal with the final timestep of one episode being followed by the first timestep of the next. Sacrificing memory efficiency for ease of use, the item buffer can instead be used to store transitions or entire trajectories independently; the algorithm then never has to handle an episode boundary inside a sampled item, since only the data that is explicitly inserted can be sampled.
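As a purely illustrative convention (not something Flashbax enforces), one way to encode the distinction between termination and truncation is through the stored discount when constructing the timesteps that get added:

```python
import jax.numpy as jnp

# Illustrative convention only: distinguish termination from truncation via the stored
# discount so that downstream value targets are masked correctly. Placeholder values.
obs, reward = jnp.zeros(4), jnp.array(1.0)
terminated = {"obs": obs, "reward": reward, "discount": jnp.array(0.0)}  # true terminal state: stop bootstrapping
truncated = {"obs": obs, "reward": reward, "discount": jnp.array(1.0)}   # time-limit cut: still bootstrap
```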

### Independent Data Usage
For situations where you intend to use buffers with data that lacks sequential structure, you can leverage the item buffer, which is a wrapped trajectory buffer with a specific configuration: by setting a sequence length of 1 and a period of 1, each item is treated as independent. However, when working with independent transition items like (observation, action, reward, discount, next_observation), be mindful that this approach results in duplicate observations within the buffer, leading to unnecessary memory consumption. It is also important to note that the flat buffer will be slower than using the item buffer in this way, due to the inherent speed issues of data indexing on hardware accelerators; this trade-off is made to improve memory efficiency. If speed is strongly preferred over memory efficiency, then use the trajectory buffer with sequence length 1 and period 1 (i.e. the item buffer) and store full transitions as the data items, as sketched below.
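A minimal sketch of that option, assuming the same `max_length`/`min_length`/`sample_batch_size` arguments used elsewhere in this README (values and field names below are arbitrary):

```python
import jax.numpy as jnp
import flashbax as fbx

# A minimal sketch: store full transitions as independent items in the item buffer.
# Argument values and field names are arbitrary.
buffer = fbx.make_item_buffer(max_length=10_000, min_length=32, sample_batch_size=32)

transition = {
    "obs": jnp.zeros(4),
    "action": jnp.array(0),
    "reward": jnp.array(0.0),
    "discount": jnp.array(1.0),
    "next_obs": jnp.zeros(4),
}
state = buffer.init(transition)
state = buffer.add(state, transition)  # each added transition is sampled independently
```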

### In-place Updating of Buffer State
Since buffers are generally large and occupy a significant portion of device memory, it is beneficial to perform in-place updates. To do this, it is important to specify to the top-level compiled function that you would like to perform this in-place update operation. This is indicated as follows:


```python
def train(train_state, buffer_state):
    ...
    return train_state, buffer_state

# Create the buffer and initialise the buffer state
buffer_fn = fbx.make_trajectory_buffer(...)
buffer_state = buffer_fn.init(example_timestep)

# Initialise some training state (e.g. network parameters, optimiser state)
train_state = ...

# Compile the training function and specify the donation of the buffer state argument
train_state, buffer_state = jax.jit(train, donate_argnums=(1,))(
    train_state, buffer_state
)
```

It is important to include `donate_argnums` when calling `jax.jit` to enable JAX to perform an in-place update of the replay buffer state. Omitting `donate_argnums` would force JAX to create a copy of the state for any modifications to the replay buffer state, potentially negating all performance benefits. More information about buffer donation in JAX can be found in the [documentation](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).


### Storing Data with Vault
As mentioned above, Vault stores experience data to disk by extending the temporal axis of a Flashbax buffer state. By default, Vault conveniently handles the bookkeeping of this process: consuming a buffer state and saving any fresh, previously unseen data. For example, suppose we write 10 timesteps to our Flashbax buffer and then save this state to a Vault; since all of this data is fresh, all of it will be written to disk. However, if we then write one more timestep and save the state to the Vault, only that new timestep will be written, preventing any duplication of data that has already been saved. Importantly, one must remember that Flashbax states are implemented as _ring buffers_, meaning the Vault must be updated sufficiently frequently, before unseen data in the Flashbax buffer state is overwritten. That is, if our buffer state has a time-axis length of $\tau$, then we must save to the Vault at least every $\tau - 1$ steps, lest we overwrite (and lose) unsaved data.

In summary, understanding and addressing these considerations will help you navigate potential pitfalls and ensure the effectiveness of your reinforcement learning strategies while utilising Flashbax buffers.

## Benchmarks 📈

Here we provide a series of initial benchmarks outlining the performance of the various Flashbax buffers compared against commonly used open-source buffers. Unless explicitly stated otherwise, these benchmarks use the following configuration:

| Parameter               | Value       |
|-------------------------|-------------|
| Buffer Size             | 500_000     |
| Sample Batch Size       | 256         |
| Observation Size        | (32, 32, 3) |
| Add Sequence Length     | 1           |
| Add Sequence Batch Size | 1           |
| Sample Sequence Length  | 1           |
| Sample Sequence Period  | 1           |

We use a sample sequence length and period of 1 to compare directly with the other buffers; this means the speeds for the trajectory buffer are comparable to those of the item buffer, since the item buffer is simply a wrapped trajectory buffer with this configuration. Essentially, the trajectory buffers are being used as memory-inefficient transition buffers. It is important to note that the Flat Buffer implementations use a sample sequence length of 2. Additionally, bear in mind that not all other buffer implementations can efficiently make use of GPUs/TPUs, so they simply run on the CPU and perform device conversions. Lastly, we explicitly make use of Python loops to compare fairly with the other buffers; speeds can be improved considerably using scan operations (depending on observation size), as sketched below.
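For reference, such a scan-based add might look like the sketch below, reusing the flat buffer and `state` from the Quickstart (shapes and values are arbitrary):

```python
# A sketch of adding many timesteps with jax.lax.scan rather than a Python loop,
# reusing the Quickstart flat buffer. Shapes and values are arbitrary.
def add_step(carry_state, timestep):
    return buffer.add(carry_state, timestep), None

timesteps = {
    "obs": jnp.zeros((100, 2), dtype=jnp.int32),  # 100 timesteps, each of shape (2,)
    "reward": jnp.zeros(100),
}
state, _ = jax.lax.scan(add_step, state, timesteps)
```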

### CPU Speeds

<p float="left">
<img alt="CPU_Add" src="docs/imgs/cpu_add.png" width="49%">
<img alt="CPU_Sample" src="docs/imgs/cpu_sample.png" width="49%">
</p>

### TPU Speeds
<p float="left">
<img alt="TPU_Add" src="docs/imgs/tpu_add.png" width="49%">
<img alt="TPU_Sample" src="docs/imgs/tpu_sample.png" width="49%">
</p>

### GPU Speeds

We notice strange behaviour in GPU speeds when adding data. We believe this is because certain JAX operations are not yet fully optimised for GPU usage, as Dejax exhibits the same performance issues. We expect these speeds to improve in the future.

<p float="left">
<img alt="GPU_Add" src="docs/imgs/gpu_add.png" width="49%">
<img alt="GPU_Sample" src="docs/imgs/gpu_sample.png" width="49%">
</p>

### CPU, GPU, & TPU Adding Batches
The previous benchmarks added only a single timestep at a time; we now evaluate adding batches of 128 timesteps at a time, a feature most relevant to high-throughput RL. We only compare against the buffers that have this capability.

<p float="left">
<img alt="CPU_Add_Batch" src="docs/imgs/cpu_add_batch.png" width="49%">
<img alt="TPU_Add_Batch" src="docs/imgs/tpu_add_batch.png" width="49%">
</p>

<p align="center">
<img alt="GPU_Add_Batch" src="docs/imgs/gpu_add_batch.png" width="49%">
</p>

Ultimately, we see improved or comparable performance relative to the benchmarked buffers, whilst providing buffers that are fully JAX-compatible and offer additional features such as batched adding and the ability to add sequences of varying length. We do note that, because JAX uses different XLA backends for CPU, GPU, and TPU, the performance of the buffers can vary depending on the device and the specific operation being called.


## Contributing 🤝

Contributions are welcome! See our issue tracker for
[good first issues](https://github.com/instadeepai/flashbax/labels/good%20first%20issue). Please read
our [contributing guidelines](https://github.com/instadeepai/flashbax/blob/main/CONTRIBUTING.md) for
details on how to submit pull requests, our Contributor License Agreement, and community guidelines.

## See Also 📚
Check out some of the other buffer libraries from the community that we have highlighted in our
benchmarks.

- 📀 [Dejax](https://github.com/hr0nix/dejax): the first library to provide JAX-compatible replay buffers.
- 🎶 [Reverb](https://github.com/google-deepmind/reverb): efficient replay buffers used for both local and large-scale distributed RL.
- 🍰 [Dopamine](https://github.com/google/dopamine/blob/master/dopamine/replay_memory/): a research framework for fast prototyping, providing several core replay buffers.
- 🤖 [StableBaselines3](https://stable-baselines3.readthedocs.io/en/master/): a suite of reliable RL baselines with its own easy-to-use replay buffers.

## Citing Flashbax ✏️

If you use Flashbax in your work, please cite the library using:

```
@misc{flashbax,
    title={Flashbax: Streamlining Experience Replay Buffers for Reinforcement Learning with JAX},
    author={Edan Toledo and Laurence Midgley and Donal Byrne and Callum Rhys Tilbury and
    Matthew Macfarlane and Cyprien Courtot and Alexandre Laterre},
    year={2023},
    url={https://github.com/instadeepai/flashbax/},
}
```

## Acknowledgements 🙏

The development of this library was supported with Cloud TPUs
from Google's [TPU Research Cloud](https://sites.research.google/trc/about/) (TRC) 🌤️.

            
