mctx


| Field | Value |
| --- | --- |
| Name | mctx |
| Version | 0.0.5 |
| Home page | https://github.com/google-deepmind/mctx |
| Summary | Monte Carlo tree search in JAX. |
| Upload time | 2023-11-24 11:52:32 |
| Author | DeepMind |
| Requires Python | >=3.9 |
| License | Apache 2.0 |
| Keywords | jax planning reinforcement-learning python machine learning |
| Requirements | No requirements were recorded. |

# Mctx: MCTS-in-JAX

Mctx is a library with a [JAX](https://github.com/google/jax)-native
implementation of Monte Carlo tree search (MCTS) algorithms such as
[AlphaZero](https://deepmind.com/blog/article/alphazero-shedding-new-light-grand-games-chess-shogi-and-go),
[MuZero](https://deepmind.com/blog/article/muzero-mastering-go-chess-shogi-and-atari-without-rules), and
[Gumbel MuZero](https://openreview.net/forum?id=bERaNdoegnO). To speed up
computation, the implementation fully supports JIT compilation. Search algorithms
in Mctx are defined for and operate on batches of inputs, in parallel. This
makes the most of accelerators and enables the algorithms to work
with large learned environment models parameterized by deep neural networks.

## Installation

You can install the latest released version of Mctx from PyPI via:

```sh
pip install mctx
```

or you can install the latest development version from GitHub:

```sh
pip install git+https://github.com/google-deepmind/mctx.git
```

## Motivation

Learning and search have been important topics since the early days of AI
research. In the [words of Rich Sutton](http://www.incompleteideas.net/IncIdeas/BitterLesson.html):

> One thing that should be learned [...] is the great power of general purpose
> methods, of methods that continue to scale with increased computation even as
> the available computation becomes very great. The two methods that seem to
> scale arbitrarily in this way are *search* and *learning*.

Recently, search algorithms have been successfully combined with learned models
parameterized by deep neural networks, resulting in some of the most powerful
and general reinforcement learning algorithms to date (e.g. MuZero).
However, using search algorithms in combination with deep neural networks
requires efficient implementations, typically written in fast compiled
languages; this can come at the expense of usability and hackability,
especially for researchers who are not familiar with C++. In turn, this limits
adoption and further research on this critical topic.

Through this library, we hope to help researchers everywhere contribute to
this exciting area of research. We provide JAX-native implementations of core
search algorithms such as MCTS, which we believe strike a good balance between
performance and usability for researchers who want to investigate search-based
algorithms in Python. The search methods provided by Mctx are
heavily configurable to allow researchers to explore a variety of ideas in
this space and contribute to the next generation of search-based agents.

## Search in Reinforcement Learning

In reinforcement learning, the *agent* must learn to interact with the
*environment* in order to maximize a scalar *reward* signal. On each step, the
agent selects an action and receives in exchange an observation and a
reward. We call whatever mechanism the agent uses to select the action the
agent's *policy*.

Classically, policies are parameterized directly by a function approximator (as
in REINFORCE), or policies are inferred by inspecting a set of learned estimates
of the value of each action (as in Q-learning). Alternatively, search allows the
agent to select actions by constructing on the fly, in each state, a policy or a
value function local to the current state, by *searching* using a learned
*model* of the environment.

Exhaustive search over all possible future courses of action is computationally
prohibitive in any non-trivial environment, so we need search algorithms
that make the best use of a finite computational budget. Typically, priors
are needed to guide which nodes in the search tree to expand (to reduce the
*breadth* of the tree that we construct), and value functions are used to
estimate the value of incomplete paths in the tree that don't reach an episode
termination (to reduce the *depth* of the search tree).
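
The two roles can be illustrated with a small sketch. This is a plain depth-limited lookahead in pure Python, not MCTS and not part of Mctx. The toy environment, `model`, `prior`, and `value_fn` are all hypothetical stand-ins for learned components.

```python
from typing import Dict, Optional, Tuple

ACTIONS = (-1, 0, 1)  # move left, stay, move right on a line
GOAL = 3              # hypothetical goal position


def model(state: int, action: int) -> Tuple[int, float]:
    """Hypothetical deterministic model: returns (next_state, reward)."""
    next_state = state + action
    reward = 1.0 if next_state == GOAL else 0.0
    return next_state, reward


def prior(state: int) -> Dict[int, float]:
    """Hypothetical policy prior that favors moving towards the goal."""
    toward = 0 if state == GOAL else (1 if state < GOAL else -1)
    return {a: (0.7 if a == toward else 0.15) for a in ACTIONS}


def value_fn(state: int) -> float:
    """Hypothetical value estimate: larger when closer to the goal."""
    return 0.9 ** abs(GOAL - state)


def search(state: int, depth: int, top_k: int = 2,
           discount: float = 0.9) -> Tuple[float, Optional[int]]:
    """Returns (estimated value, best action) for `state`."""
    if depth == 0:
        # Depth reduction: stop early and bootstrap from the value function.
        return value_fn(state), None
    # Breadth reduction: expand only the top_k actions under the prior.
    probs = prior(state)
    candidates = sorted(ACTIONS, key=lambda a: probs[a], reverse=True)[:top_k]
    best_value, best_action = float("-inf"), None
    for action in candidates:
        next_state, reward = model(state, action)
        future_value, _ = search(next_state, depth - 1, top_k, discount)
        q_value = reward + discount * future_value
        if q_value > best_value:
            best_value, best_action = q_value, action
    return best_value, best_action


value, action = search(state=0, depth=2)
print(f"best action: {action}, estimated value: {value:.3f}")
```

With `top_k=2`, the lookahead evaluates at most `2**depth` leaves instead of `3**depth`, and the value function removes the need to roll out to the end of an episode.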

## Quickstart

Mctx provides a low-level generic `search` function and high-level concrete
policies: `muzero_policy` and `gumbel_muzero_policy`.

The user needs to provide several learned components to specify the
representation, dynamics and prediction used by [MuZero](https://deepmind.com/blog/article/muzero-mastering-go-chess-shogi-and-atari-without-rules).
In the context of the Mctx library, the representation of the *root* state is
specified by a `RootFnOutput`. The `RootFnOutput` contains the `prior_logits`
from a policy network, the estimated `value` of the root state, and any
`embedding` suitable to represent the root state for the environment model.

The environment dynamics model needs to be specified by a `recurrent_fn`.
A `recurrent_fn(params, rng_key, action, embedding)` call takes an `action` and
a state `embedding`. The call should return a tuple `(recurrent_fn_output,
new_embedding)` with a `RecurrentFnOutput` and the embedding of the next state.
The `RecurrentFnOutput` contains the `reward` and `discount` for the transition,
and `prior_logits` and `value` for the new state.

In [`examples/visualization_demo.py`](https://github.com/google-deepmind/mctx/blob/main/examples/visualization_demo.py), you can
see calls to a policy:

```python
policy_output = mctx.gumbel_muzero_policy(params, rng_key, root, recurrent_fn,
                                          num_simulations=32)
```

`policy_output.action` contains the action proposed by the search. That
action can be passed to the environment. To improve the policy,
`policy_output.action_weights` contains targets that can be used to train the
policy probabilities.
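
One possible way to use these targets, shown here as a sketch and not as the training setup used by any particular agent, is a cross-entropy loss between the search targets and the policy network's logits. The function name `policy_loss` and the example arrays are hypothetical; in practice `action_weights` would be `policy_output.action_weights`.

```python
import jax
import jax.numpy as jnp


def policy_loss(policy_logits, action_weights):
    """Cross-entropy between search targets and the policy network output."""
    # Treat the search result as a fixed target.
    targets = jax.lax.stop_gradient(action_weights)
    log_probs = jax.nn.log_softmax(policy_logits, axis=-1)
    return -jnp.mean(jnp.sum(targets * log_probs, axis=-1))


# Hypothetical stand-ins for `policy_output.action_weights` and for the
# logits produced by a policy network, both of shape [batch, num_actions].
action_weights = jnp.array([[0.1, 0.7, 0.2],
                            [0.6, 0.3, 0.1]])
policy_logits = jnp.zeros_like(action_weights)

loss = policy_loss(policy_logits, action_weights)
grads = jax.grad(policy_loss)(policy_logits, action_weights)
print(loss, grads.shape)
```

The gradient is taken with respect to the logits only. In a full training loop it would flow further back into the policy network's parameters.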

We recommend using `gumbel_muzero_policy`.
[Gumbel MuZero](https://openreview.net/forum?id=bERaNdoegnO) guarantees a policy
improvement if the action values are correctly evaluated. The policy improvement
is demonstrated in
[`examples/policy_improvement_demo.py`](https://github.com/google-deepmind/mctx/blob/main/examples/policy_improvement_demo.py).

### Example projects
The following projects demonstrate how to use Mctx:

- [Pgx](https://github.com/sotetsuk/pgx) — A collection of 20+ vectorized
  JAX environments, including backgammon, chess, shogi, Go, and an AlphaZero
  example.
- [Basic Learning Demo with Mctx](https://github.com/kenjyoung/mctx_learning_demo) —
  AlphaZero on random mazes.
- [a0-jax](https://github.com/NTT123/a0-jax) — AlphaZero on Connect Four,
  Gomoku, and Go.
- [muax](https://github.com/bwfbowen/muax) — MuZero on gym-style environments
  (CartPole, LunarLander).
- [Classic MCTS](https://github.com/Carbon225/mctx-classic) — A simple example on Connect Four.

Tell us about your project.

## Citing Mctx

This repository is part of the DeepMind JAX Ecosystem. To cite Mctx,
please use the following citation:

```bibtex
@software{deepmind2020jax,
  title = {The {D}eep{M}ind {JAX} {E}cosystem},
  author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
  url = {http://github.com/deepmind},
  year = {2020},
}
```

            

Raw data

            {
    "_id": null,
    "home_page": "https://github.com/google-deepmind/mctx",
    "name": "mctx",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.9",
    "maintainer_email": "",
    "keywords": "jax planning reinforcement-learning python machine learning",
    "author": "DeepMind",
    "author_email": "mctx-dev@google.com",
    "download_url": "https://files.pythonhosted.org/packages/77/03/90f4c3cd7c729422c4667fa4a6c74d8b105cae0e5fade1e7c253f11bb60c/mctx-0.0.5.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "Apache 2.0",
    "summary": "Monte Carlo tree search in JAX.",
    "version": "0.0.5",
    "project_urls": {
        "Homepage": "https://github.com/google-deepmind/mctx"
    },
    "split_keywords": [
        "jax",
        "planning",
        "reinforcement-learning",
        "python",
        "machine",
        "learning"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "570409a8cf47742ee18c8596ec86f0b4558548bf0ffaf0eb83506029bd9750e2",
                "md5": "de97051bf8b61639afe89aac8ac9925e",
                "sha256": "d263830a1e44a16fe2ede5ed5b37fd83626459bfccbb1ab814a6bd49bf62ffd3"
            },
            "downloads": -1,
            "filename": "mctx-0.0.5-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "de97051bf8b61639afe89aac8ac9925e",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.9",
            "size": 45285,
            "upload_time": "2023-11-24T11:52:30",
            "upload_time_iso_8601": "2023-11-24T11:52:30.703715Z",
            "url": "https://files.pythonhosted.org/packages/57/04/09a8cf47742ee18c8596ec86f0b4558548bf0ffaf0eb83506029bd9750e2/mctx-0.0.5-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "770390f4c3cd7c729422c4667fa4a6c74d8b105cae0e5fade1e7c253f11bb60c",
                "md5": "26316e87717df35dcd89bb1760820291",
                "sha256": "e9f669bf4fd4c4f61837be6f9ab0ca60180945108c36bcdf5beaabc481020e21"
            },
            "downloads": -1,
            "filename": "mctx-0.0.5.tar.gz",
            "has_sig": false,
            "md5_digest": "26316e87717df35dcd89bb1760820291",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.9",
            "size": 36874,
            "upload_time": "2023-11-24T11:52:32",
            "upload_time_iso_8601": "2023-11-24T11:52:32.340731Z",
            "url": "https://files.pythonhosted.org/packages/77/03/90f4c3cd7c729422c4667fa4a6c74d8b105cae0e5fade1e7c253f11bb60c/mctx-0.0.5.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2023-11-24 11:52:32",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "google-deepmind",
    "github_project": "mctx",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "lcname": "mctx"
}
        