rlax

Name: rlax
Version: 0.1.6
Summary: A library of reinforcement learning building blocks in JAX.
Home page: https://github.com/deepmind/rlax
Author: DeepMind
Requires Python: >=3.9
License: Apache 2.0
Keywords: reinforcement-learning, python, machine learning
Upload time: 2023-06-29 15:03:36

# RLax

![CI status](https://github.com/deepmind/rlax/workflows/ci/badge.svg)
![docs](https://readthedocs.org/projects/rlax/badge/?version=latest)
![pypi](https://img.shields.io/pypi/v/rlax)

RLax (pronounced "relax") is a library built on top of JAX that exposes
useful building blocks for implementing reinforcement learning agents. Full
documentation can be found at
 [rlax.readthedocs.io](https://rlax.readthedocs.io/en/latest/index.html).

## Installation

You can install the latest released version of RLax from PyPI via:

```sh
pip install rlax
```

or you can install the latest development version from GitHub:

```sh
pip install git+https://github.com/deepmind/rlax.git
```

All RLax code may then be just-in-time compiled for different hardware
(e.g. CPU, GPU, TPU) using `jax.jit`.
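
For instance, a minimal sketch of jit-compiling a single op (values are
illustrative only; see the `rlax.td_learning` doc-string for the exact
signature):

```python
import jax
import jax.numpy as jnp
import rlax

# Jit-compile the TD-learning error computation; subsequent calls with
# inputs of the same shapes reuse the compiled XLA program.
td_learning = jax.jit(rlax.td_learning)

# A single, illustrative transition: value estimates, reward and discount.
v_tm1 = jnp.array(1.0)        # value estimate in the source state
r_t = jnp.array(0.5)          # reward observed on the transition
discount_t = jnp.array(0.99)  # discount for the transition
v_t = jnp.array(1.2)          # value estimate in the destination state

td_error = td_learning(v_tm1, r_t, discount_t, v_t)
print(td_error)  # r_t + discount_t * v_t - v_tm1
```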

To run the code in `examples/` you will also need to clone the repo and
install the additional requirements:
[optax](https://github.com/deepmind/optax),
[haiku](https://github.com/deepmind/haiku), and
[bsuite](https://github.com/deepmind/bsuite).
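
For example, assuming the PyPI package names `optax`, `dm-haiku` and `bsuite`
(check each project's README if in doubt), something like the following should
install them:

```sh
pip install optax dm-haiku bsuite
```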

## Content

The operations and functions provided are not complete algorithms, but
implementations of reinforcement-learning-specific mathematical operations that
are needed when building fully-functional agents capable of learning:

* Values, including both state- and action-values;
* Values for non-linear generalizations of the Bellman equations;
* Return Distributions, aka distributional value functions;
* General Value Functions, for cumulants other than the main reward;
* Policies, via policy-gradients in both continuous and discrete action spaces
  (a brief sketch follows this list).
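
As an example of the last category, a minimal sketch of a policy-gradient
surrogate loss (the trajectory data below is illustrative, and the argument
order of `rlax.policy_gradient_loss` is assumed from its doc-string):

```python
import jax.numpy as jnp
import rlax

# Illustrative on-policy data for a trajectory of length T=3 with 4 discrete
# actions: policy logits, sampled actions, advantage estimates and weights.
logits_t = jnp.zeros((3, 4))         # policy logits at each step
a_t = jnp.array([0, 3, 1])           # actions sampled from that policy
adv_t = jnp.array([1.0, -0.5, 0.2])  # advantage estimates
w_t = jnp.ones((3,))                 # per-step weights (e.g. masks)

# Scalar surrogate loss; differentiating it w.r.t. the parameters that
# produced `logits_t` yields a policy-gradient update.
loss = rlax.policy_gradient_loss(logits_t, a_t, adv_t, w_t)
```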

The library supports both on-policy and off-policy learning (i.e. learning from
data sampled from a policy different from the agent's policy).

See file-level and function-level doc-strings for the documentation of these
functions and for references to the papers that introduced and/or used them.

## Usage

See `examples/` for examples of using some of the functions in RLax to
implement a few simple reinforcement learning agents, and to demonstrate
learning on BSuite's version of the Catch environment (a common unit-test for
agent development in the reinforcement learning literature).

Other examples of JAX reinforcement learning agents using `rlax` can be found in
[bsuite](https://github.com/deepmind/bsuite/tree/master/bsuite/baselines).
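
For a quick flavour of the API outside of those examples, a minimal sketch of
epsilon-greedy action selection (the Q-values are illustrative, and the
`rlax.epsilon_greedy` distribution interface is assumed from the doc-strings):

```python
import jax
import jax.numpy as jnp
import rlax

key = jax.random.PRNGKey(42)

# Action values for a single state with 4 discrete actions (illustrative).
q_values = jnp.array([0.1, 0.5, 0.2, 0.0])

# `rlax.epsilon_greedy` returns a discrete distribution; sampling from it
# picks the greedy action with probability (1 - epsilon) and a uniformly
# random action otherwise.
a = rlax.epsilon_greedy(epsilon=0.1).sample(key, q_values)
print(int(a))
```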

## Background

Reinforcement learning studies the problem of a learning system (the *agent*),
which must learn to interact with the universe it is embedded in (the
*environment*).

Agent and environment interact on discrete steps. On each step the agent selects
an *action*, and is provided in return a (partial) snapshot of the state of the
environment (the *observation*), and a scalar feedback signal (the *reward*).

The behaviour of the agent is characterized by a probability distribution over
actions, conditioned on past observations of the environment (the *policy*). The
agent seeks a policy that, from any given step, maximises the discounted
cumulative reward that will be collected from that point onwards (the *return*).

Often the agent's policy or the environment dynamics are themselves stochastic.
In this case the return is a random variable, and the agent's optimal policy is
typically more precisely specified as one that maximises the expectation of the
return (the *value*), under the agent's and the environment's stochasticity.
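
For reference, in the standard discounted formulation (a conventional
definition, not specific to RLax), the return from step t and the value of a
state under a policy π can be written as:

```latex
% Return from step t, with discount factor \gamma \in [0, 1]:
G_t = \sum_{k=0}^{\infty} \gamma^{k} R_{t+k+1}

% Value of state s under policy \pi: the expected return.
v_{\pi}(s) = \mathbb{E}_{\pi}\left[ G_t \mid S_t = s \right]
```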

## Reinforcement Learning Algorithms

There are three prototypical families of reinforcement learning algorithms:

1.  those that estimate the value of states and actions, and infer a policy by
    *inspection* (e.g. by selecting the action with the highest estimated value);
2.  those that learn a model of the environment (capable of predicting the
    observations and rewards) and infer a policy via *planning*;
3.  those that parameterize a policy that can be directly *executed*.

In any case, policies, values or models are just functions. In deep
reinforcement learning such functions are represented by a neural network.
In this setting, it is common to formulate reinforcement learning updates as
differentiable pseudo-loss functions (analogously to (un-)supervised learning);
differentiating such a pseudo-loss then recovers the original update rule.
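
As an illustration of this pattern, here is a minimal sketch with a
hypothetical linear Q-function (`q_fn`, `pseudo_loss` and all values below are
illustrative, not part of RLax): the squared Q-learning TD error serves as a
pseudo-loss, and `jax.grad` yields the parameter update direction.

```python
import jax
import jax.numpy as jnp
import rlax

NUM_ACTIONS = 3
OBS_DIM = 4

def q_fn(params, obs):
  # Hypothetical linear Q-function: params is an [OBS_DIM, NUM_ACTIONS] matrix.
  return obs @ params

def pseudo_loss(params, obs_tm1, a_tm1, r_t, discount_t, obs_t):
  q_tm1 = q_fn(params, obs_tm1)
  q_t = q_fn(params, obs_t)
  # Q-learning TD error for a single transition, squared into a pseudo-loss.
  td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
  return rlax.l2_loss(td_error)

params = jnp.zeros((OBS_DIM, NUM_ACTIONS))
obs_tm1 = jnp.ones((OBS_DIM,))
a_tm1 = jnp.array(1)          # action taken in the source state
r_t = jnp.array(0.5)          # reward on the transition
discount_t = jnp.array(0.99)
obs_t = jnp.ones((OBS_DIM,))

# Gradients of the pseudo-loss w.r.t. the Q-function parameters.
grads = jax.grad(pseudo_loss)(params, obs_tm1, a_tm1, r_t, discount_t, obs_t)
```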

Note, however, that these updates are only valid if the input data is sampled
in the correct manner. For example, a policy gradient loss is only valid if the
input trajectory is an unbiased sample from the current policy; i.e. the data
are on-policy. The library cannot check or enforce such constraints; however,
links to papers describing how each operation is used are provided in the
functions' doc-strings.

## Naming Conventions and Developer Guidelines

We define functions and operations for agents interacting with a single stream
of experience. The JAX construct `vmap` can be used to apply these same
functions to batches (e.g. to support *replay* and *parallel* data generation).
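
For example, a single-transition op such as `rlax.td_learning` can be lifted
to a batch of transitions along these lines (batch contents are illustrative):

```python
import jax
import jax.numpy as jnp
import rlax

# An illustrative batch of 3 transitions, one scalar entry per transition.
v_tm1 = jnp.array([1.0, 0.5, 2.0])
r_t = jnp.array([0.0, 1.0, -1.0])
discount_t = jnp.array([0.99, 0.99, 0.0])
v_t = jnp.array([1.1, 0.4, 1.5])

# `rlax.td_learning` is written for a single transition; `jax.vmap` maps it
# over the leading batch dimension of every argument.
batched_td_learning = jax.vmap(rlax.td_learning)
td_errors = batched_td_learning(v_tm1, r_t, discount_t, v_t)  # shape [3]
```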

Many functions consider policies, actions, rewards and values at consecutive
timesteps in order to compute their outputs. In this case the suffixes `_t` and
`_tm1` are often used to clarify on which step each input was generated, e.g.
(a worked call using these names follows the list):

*   `q_tm1`: the action value in the `source` state of a transition.
*   `a_tm1`: the action that was selected in the `source` state.
*   `r_t`: the resulting reward collected in the `destination` state.
*   `discount_t`: the `discount` associated with a transition.
*   `q_t`: the action values in the `destination` state.
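
Putting the convention together, a single-transition Q-learning call using
these names (values are illustrative; see the `rlax.q_learning` doc-string for
the exact signature) reads:

```python
import jax.numpy as jnp
import rlax

q_tm1 = jnp.array([1.0, 2.0, 0.5])  # action values in the source state
a_tm1 = jnp.array(1)                # action selected in the source state
r_t = jnp.array(0.5)                # reward collected on the transition
discount_t = jnp.array(0.99)        # discount associated with the transition
q_t = jnp.array([1.5, 0.5, 1.0])    # action values in the destination state

# TD error: r_t + discount_t * max(q_t) - q_tm1[a_tm1].
td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
```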

Extensive testing is provided for each function. All tests should also verify
the output of `rlax` functions when compiled to XLA using `jax.jit` and when
performing batch operations using `jax.vmap`.
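
A minimal sketch of such a consistency check (not taken from the RLax test
suite; data and tolerances are illustrative) could look like:

```python
import jax
import jax.numpy as jnp
import numpy as np
import rlax

v_tm1 = jnp.array([1.0, 0.5])
r_t = jnp.array([0.0, 1.0])
discount_t = jnp.array([0.99, 0.9])
v_t = jnp.array([1.1, 0.4])

# The vmapped and jit-compiled variants should agree with a plain Python loop
# over single transitions.
expected = jnp.stack([
    rlax.td_learning(v_tm1[i], r_t[i], discount_t[i], v_t[i])
    for i in range(2)
])
batched = jax.vmap(rlax.td_learning)(v_tm1, r_t, discount_t, v_t)
jitted = jax.jit(jax.vmap(rlax.td_learning))(v_tm1, r_t, discount_t, v_t)

np.testing.assert_allclose(batched, expected, rtol=1e-6)
np.testing.assert_allclose(jitted, expected, rtol=1e-6)
```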

## Citing RLax

RLax is part of the [DeepMind JAX Ecosystem]; to cite RLax, please use
the [DeepMind JAX Ecosystem citation].

[DeepMind JAX Ecosystem]: https://deepmind.com/blog/article/using-jax-to-accelerate-our-research "DeepMind JAX Ecosystem"
[DeepMind JAX Ecosystem citation]: https://github.com/deepmind/jax/blob/main/deepmind2020jax.txt "Citation"


            
