gato-tf


- Name: gato-tf
- Version: 0.0.4
- Home page: https://github.com/OrigamiDream/gato.git
- Summary: Unofficial Gato: A Generalist Agent
- Upload time: 2023-05-26 18:27:24
- Author: OrigamiDream
- Requires Python: >=3.10.0
- License: MIT
- Keywords: deep learning, gato, tensorflow, generalist agent

<h1 align="center">Unofficial Gato: A Generalist Agent</h1>

[[Deepmind Publication]](https://www.deepmind.com/publications/a-generalist-agent)
[[arXiv Paper]](https://arxiv.org/pdf/2205.06175.pdf)

This repository contains a TensorFlow imitation of DeepMind's Gato architecture.

Since DeepMind describes only parts of the architecture in its paper, much about the model remains unknown.<br>
However, I believe the paper provides enough detail to imitate the architecture, and I'm trying to do so with the help of the open source community.

Currently, the repository supports the following components:
- Gato (via [`Gato`](https://github.com/OrigamiDream/gato/blob/main/gato/models/__init__.py#L12))
- Transformer (via [`Transformer`](https://github.com/OrigamiDream/gato/blob/main/gato/models/__init__.py#L61))
- Patch Position Encodings (via [`PatchPositionEncoding`](https://github.com/OrigamiDream/gato/blob/main/gato/models/embedding.py#L38))
- Embedding Function (via [`ResidualEmbedding`](https://github.com/OrigamiDream/gato/blob/main/gato/models/embedding.py#L139))
- Local Observation Position Encodings (via [`LocalPositionEncoding`](https://github.com/OrigamiDream/gato/blob/main/gato/models/embedding.py#L199))
- Tokenizing Continuous Values (via [`ContinuousValueTokenizer`](https://github.com/OrigamiDream/gato/blob/main/gato/models/tokenizers.py#L30))
- Shared Embedding (via [`DiscreteEmbedding`](https://github.com/OrigamiDream/gato/blob/main/gato/models/embedding.py#L237))
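
To make the "Tokenizing Continuous Values" item more concrete, here is a minimal sketch in plain Python (no TensorFlow, to keep it dependency-free). It follows my reading of the paper: mu-law companding, clipping to [-1, 1], discretizing into 1024 uniform bins, and shifting past the text vocabulary. The constants and the function names `mu_law_encode` and `tokenize_continuous` are assumptions for illustration, not the API of `ContinuousValueTokenizer`, which may differ.

```python
import math

# Assumed constants, taken from my reading of the Gato paper;
# the repository's ContinuousValueTokenizer may use different defaults.
MU = 100
M = 256
NUM_BINS = 1024
TEXT_VOCAB_SIZE = 32000  # continuous tokens are shifted past the text vocabulary


def mu_law_encode(x: float, mu: float = MU, m: float = M) -> float:
    """Compress x with mu-law companding, then clip to [-1, 1]."""
    y = math.copysign(math.log(abs(x) * mu + 1.0) / math.log(m * mu + 1.0), x)
    return max(-1.0, min(1.0, y))


def tokenize_continuous(x: float) -> int:
    """Map a real value to a token id in [32000, 33024)."""
    y = mu_law_encode(x)
    # Map [-1, 1] onto bin indices 0..NUM_BINS-1 (the top edge goes to the last bin).
    bin_index = min(NUM_BINS - 1, int((y + 1.0) / 2.0 * NUM_BINS))
    return TEXT_VOCAB_SIZE + bin_index
```

Under these assumptions, zero lands in the middle bin and very large magnitudes saturate at the first or last bin.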

Action tokens are still a mystery in the paper, so I need your help.

However, the repository still lacks the following:
- Datasets (most important, Issue: [#1](https://github.com/OrigamiDream/gato/issues/1), [ThomasRochefortB/torch-gato](https://github.com/ThomasRochefortB/torch-gato/blob/main/datasets/README.md))
- <s>Pre-trained tokenizers</s> (No longer required because of E2E model)
- Training strategy (E2E, WIP)

You can still explore the basic architecture of Gato as described in the paper.

### Usage
```bash
$ pip install gato-tf
```
```python
import tensorflow as tf
from gato import Gato, GatoConfig

# Create model instance
config = GatoConfig.small()
gato = Gato(config)

# Fake inputs for Gato
input_dim = config.input_dim
input_ids = tf.concat([
  # ...
  # observation 1
  tf.random.uniform((1, 1, input_dim)),  # image patch 0
  tf.random.uniform((1, 1, input_dim)),  # image patch 1
  tf.random.uniform((1, 1, input_dim)),  # image patch 2
  # ...
  tf.random.uniform((1, 1, input_dim)),  # image patch 19
  tf.fill((1, 1, input_dim), value=0.25),  # continuous value
  tf.fill((1, 1, input_dim), value=624.0),  # discrete (actions, texts)

  # observation 2
  tf.random.uniform((1, 1, input_dim)),  # image patch 0
  tf.random.uniform((1, 1, input_dim)),  # image patch 1
  tf.random.uniform((1, 1, input_dim)),  # image patch 2
  # ...
  tf.random.uniform((1, 1, input_dim)),  # image patch 19
  tf.fill((1, 1, input_dim), value=0.12),  # continuous value
  tf.fill((1, 1, input_dim), value=295.0)  # discrete (actions, texts)
  # ...
], axis=1)
encoding = tf.constant([
  # 0 - image patch embedding
  # 1 - continuous value embedding
  # 2 - discrete embedding (actions, texts)
  [0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 1, 2]
])
row_pos = (
  tf.constant([[0.00, 0.25, 0.50, 0.75, 0, 0, 0.00, 0.25, 0.50, 0.75, 0, 0]]),  # pos_from
  tf.constant([[0.25, 0.50, 0.75, 1.00, 0, 0, 0.25, 0.50, 0.75, 1.00, 0, 0]])   # pos_to
)
col_pos = (
  tf.constant([[0.00, 0.00, 0.00, 0.80, 0, 0, 0.00, 0.00, 0.00, 0.80, 0, 0]]),  # pos_from
  tf.constant([[0.20, 0.20, 0.20, 1.00, 0, 0, 0.20, 0.20, 0.20, 1.00, 0, 0]])   # pos_to
)
obs = (
  tf.constant([[ 0,  1,  2, 19, 20, 21,  0,  1,  2, 19, 20, 21]]),  # obs token
  tf.constant([[ 1,  1,  1,  1,  1,  0,  1,  1,  1,  1,  1,  0]])   # obs token masking (for action tokens)
)
hidden_states = gato((input_ids, (encoding, row_pos, col_pos), obs))
```
### Dataset and Model Architecture
<picture>
  <source media="(prefers-color-scheme: dark)" srcset="https://user-images.githubusercontent.com/5837620/215323793-7f7bcfdb-d8be-40d3-8e58-a053511f95d5.png">
  <img alt="gato dataset and model architecture" src="https://user-images.githubusercontent.com/5837620/215323795-3a433516-f5ca-4272-9999-3df87ae521ba.png">
</picture>

## Paper Reviews

### Full Episode Sequence

<picture>
    <source media="(prefers-color-scheme: dark)" srcset="https://user-images.githubusercontent.com/5837620/175756389-31d183c9-054e-4829-93a6-df79781ca212.png">
    <img alt="gato dataset architecture" src="https://user-images.githubusercontent.com/5837620/175756409-75605dbc-7756-4509-ba93-c0ad08eea309.png">
</picture>

### Architecture Variants

> Appendix C.1. Transformer Hyperparameters

In the paper, DeepMind tested Gato with 3 architecture variants: `1.18B`, `364M`, and `79M` parameters.<br>
I have named them `large()`, `baseline()`, and `small()` respectively in `GatoConfig`.

| Hyperparameters          | Large(1.18B) | Baseline(364M) | Small(79M) |
|--------------------------|--------------|----------------|------------|
| Transformer blocks       | 24           | 12             | 8          |
| Attention heads          | 16           | 12             | 24         |
| Layer width              | 2048         | 1536           | 768        |
| Feedforward hidden size  | 8192         | 6144           | 3072       |
| Key/value size           | 128          | 128            | 32         |


### Residual Embedding

> Appendix C.2. Embedding Function

The paper does not say how many residual blocks should be stacked for token embeddings.<br>
Therefore, I have left the number configurable in `GatoConfig`.

However many residual layers there are, full pre-activation is key.

The blocks consist of:
- Version 2 ResNet architecture (based on ResNet50V2)
- GroupNorm (instead of LayerNorm)
- GeLU (instead of ReLU)
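
The ordering that "full pre-activation" implies can be sketched in plain Python as follows. This is a toy under stated assumptions, not the code of `ResidualEmbedding`: it works on a flat channel vector, uses dense weight matrices in place of convolutions, and omits the learnable scale and shift of GroupNorm. All function names are hypothetical.

```python
import math


def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def group_norm(x: list[float], groups: int, eps: float = 1e-5) -> list[float]:
    """Normalize each contiguous group of channels to zero mean, unit variance."""
    size = len(x) // groups
    out = []
    for g in range(groups):
        chunk = x[g * size:(g + 1) * size]
        mean = sum(chunk) / size
        var = sum((c - mean) ** 2 for c in chunk) / size
        out.extend((c - mean) / math.sqrt(var + eps) for c in chunk)
    return out


def linear(w: list[list[float]], x: list[float]) -> list[float]:
    """Dense map standing in for the convolution of a real ResNet block."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def preact_block(x, w1, w2, groups=2):
    """Full pre-activation: (norm -> activation -> weights) twice, then add the skip."""
    h = linear(w1, [gelu(v) for v in group_norm(x, groups)])
    h = linear(w2, [gelu(v) for v in group_norm(h, groups)])
    return [a + b for a, b in zip(x, h)]
```

Because normalization and activation come before each weight layer, the skip path carries `x` through untouched; with the last weight matrix set to zero the block reduces to the identity.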

### Position Encodings

> Appendix C.3. Position Encodings

#### Patch Position Encodings

Like [Vision Transformer (ViT)](https://github.com/google-research/vision_transformer) by Google, Gato takes input images as 16x16 patches in raster order.<br>
Unlike the Vision Transformer, however, Gato uses a different patch position encoding strategy for each of 2 phases: training and evaluation.

For high-performance computation in TensorFlow, I have used the following expressions.

$C$ and $R$ denote column-wise and row-wise, and $F$ and $T$ denote `from` and `to`, respectively.

$$
\begin{align}
  v^R_F &= \begin{bmatrix}
    0 & 32 & 64 & 96
  \end{bmatrix} \\
  v^R_T &= \begin{bmatrix}
    32 & 64 & 96 & 128
  \end{bmatrix} \\
  v^C_F &= \begin{bmatrix}
    0 & 26 & 51 & 77 & 102
  \end{bmatrix} \\
  v^C_T &= \begin{bmatrix}
    26 & 51 & 77 & 102 & 128
  \end{bmatrix} \\
  \\
  P_R &= \begin{cases}
    \mathsf{if} \ \mathsf{training} & v^R_F + \mathsf{uniform}(v^R_T - v^R_F) \\
    \mathsf{otherwise} & \mathsf{round}(\frac{v^R_F + v^R_T}{2})
  \end{cases} \\
  P_C &= \begin{cases}
    \mathsf{if} \ \mathsf{training} & v^C_F + \mathsf{uniform}(v^C_T - v^C_F) \\
    \mathsf{otherwise} & \mathsf{round}(\frac{v^C_F + v^C_T}{2})
  \end{cases} \\
  \\
  E^R_P &= P_R \cdot 1^{\mathsf{T}}_C \\
  E^C_P &= 1^{\mathsf{T}}_R \cdot P_C \\
  \\
  \therefore E &= E_I + E^R_P + E^C_P
\end{align}
$$
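
The index selection in $P_R$ and $P_C$ can be sketched in plain Python as below. It assumes that the integer vectors above come from scaling normalized positions to a range of 128 and rounding, and that every interval satisfies `to > from` (image patches only). The names `quantize` and `position_indices` are hypothetical and are not part of `PatchPositionEncoding`.

```python
import random

VOCAB_SIZE = 128  # assumed size of the position-index range used in the vectors above


def quantize(pos: float, vocab_size: int = VOCAB_SIZE) -> int:
    """Turn a normalized position in [0, 1] into an integer index."""
    return round(pos * vocab_size)


def position_indices(pos_from, pos_to, training, rng=None):
    """Pick one index per patch interval [from, to); assumes to > from."""
    rng = rng or random.Random()
    indices = []
    for f, t in zip(pos_from, pos_to):
        lo, hi = quantize(f), quantize(t)
        if training:
            # v_F + uniform(v_T - v_F): a random index inside the interval.
            indices.append(lo + rng.randrange(hi - lo))
        else:
            # round((v_F + v_T) / 2): the interval midpoint.
            indices.append(round((lo + hi) / 2))
    return indices


rows_from = [i / 4 for i in range(4)]      # 0.00, 0.25, 0.50, 0.75
rows_to = [(i + 1) / 4 for i in range(4)]  # 0.25, 0.50, 0.75, 1.00
print(position_indices(rows_from, rows_to, training=False))  # [16, 48, 80, 112]
```

The resulting row and column indices are then looked up in their respective embedding tables and added to the patch embedding, as in the last line of the equations.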

#### Local Observation Position Encodings

By the definition in Appendix B, text tokens, image patch tokens, and discrete and continuous values are all observation tokens.<br>
When Gato receives those values, they must be encoded with their position within their own (local) time step.
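
As a rough illustration, the sketch below builds local positions that restart at 0 on every time step, together with a mask that is 0 for action tokens, mirroring the `obs` tuple in the Usage example. The function `local_positions` is hypothetical and is not part of `LocalPositionEncoding`.

```python
def local_positions(num_obs_tokens: int, num_action_tokens: int, timesteps: int):
    """Return (positions, mask) for a flattened episode.

    Positions restart at 0 on every time step. The mask is 1 for observation
    tokens and 0 for action tokens.
    """
    positions, mask = [], []
    step_len = num_obs_tokens + num_action_tokens
    for _ in range(timesteps):
        positions.extend(range(step_len))
        mask.extend([1] * num_obs_tokens + [0] * num_action_tokens)
    return positions, mask


positions, mask = local_positions(num_obs_tokens=3, num_action_tokens=1, timesteps=2)
print(positions)  # [0, 1, 2, 3, 0, 1, 2, 3]
print(mask)       # [1, 1, 1, 0, 1, 1, 1, 0]
```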

## Requirements

```bash
pip install "tensorflow>=2.11.0"
```

## Contributing

This repository is still a work in progress.<br>
Currently, no downloads and no executables are provided.

Contributors who can help are very welcome.

## License
Licensed under the [MIT license](https://github.com/OrigamiDream/gato/blob/main/LICENSE).

            

Raw data

            {
    "_id": null,
    "home_page": "https://github.com/OrigamiDream/gato.git",
    "name": "gato-tf",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.10.0",
    "maintainer_email": "",
    "keywords": "deep learning,gato,tensorflow,generalist agent",
    "author": "OrigamiDream",
    "author_email": "hello@origamidream.me",
    "download_url": "https://files.pythonhosted.org/packages/a6/8d/bf98e8af2dcc809db898e1e4545a38b3200178ade4808565b30a58d00e41/gato-tf-0.0.4.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "MIT",
    "summary": "Unofficial Gato: A Generalist Agent",
    "version": "0.0.4",
    "project_urls": {
        "Homepage": "https://github.com/OrigamiDream/gato.git"
    },
    "split_keywords": [
        "deep learning",
        "gato",
        "tensorflow",
        "generalist agent"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "b76f8c7b6e986ae636d8bf37e1aae8133e5721601281e01434ea68bf99a0cc11",
                "md5": "cb1618bc45ce3bdcf6de23055f62b210",
                "sha256": "8dadb41da9fe748721d750a5372d69f58b8a07e2123620372ff935d51c5e28c0"
            },
            "downloads": -1,
            "filename": "gato_tf-0.0.4-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "cb1618bc45ce3bdcf6de23055f62b210",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.10.0",
            "size": 11859,
            "upload_time": "2023-05-26T18:27:22",
            "upload_time_iso_8601": "2023-05-26T18:27:22.690025Z",
            "url": "https://files.pythonhosted.org/packages/b7/6f/8c7b6e986ae636d8bf37e1aae8133e5721601281e01434ea68bf99a0cc11/gato_tf-0.0.4-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "a68dbf98e8af2dcc809db898e1e4545a38b3200178ade4808565b30a58d00e41",
                "md5": "8c1f9e86d8a5c3761faea8b45dbbbab6",
                "sha256": "2f3954448df6c32be79cf4da587fae2e8a9626c89db3cdea00b94bf94d0fdaa6"
            },
            "downloads": -1,
            "filename": "gato-tf-0.0.4.tar.gz",
            "has_sig": false,
            "md5_digest": "8c1f9e86d8a5c3761faea8b45dbbbab6",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.10.0",
            "size": 12394,
            "upload_time": "2023-05-26T18:27:24",
            "upload_time_iso_8601": "2023-05-26T18:27:24.140661Z",
            "url": "https://files.pythonhosted.org/packages/a6/8d/bf98e8af2dcc809db898e1e4545a38b3200178ade4808565b30a58d00e41/gato-tf-0.0.4.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2023-05-26 18:27:24",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "OrigamiDream",
    "github_project": "gato",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "lcname": "gato-tf"
}
        