[![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf)
<h1 align="center">Gato: A Generalist Agent</h1>
[[DeepMind Publication]](https://www.deepmind.com/publications/a-generalist-agent)
[[arXiv Paper]](https://arxiv.org/pdf/2205.06175.pdf)
### Installation
```bash
$ pip install gato-torch
```
```python
import torch
from gato import Gato

input_dim = 768  # size of every input token embedding; must match Gato(input_dim=...)

# create model instance (hyperparameters of the "small" 79M variant)
gato = Gato(input_dim=input_dim,
            img_patch_size=16,
            token_sequence_length=1024,
            vocabulary_size=32000,
            actions_size=1024,
            continuous_values_size=1024,
            num_transformer_blocks=8,
            num_attention_heads=24,
            layer_width=768,
            feedforward_hidden_size=3072,
            key_value_size=32,
            dropout_rate=0.1,
            num_group_norm_groups=32,
            discretize_depth=128,
            local_position_encoding_size=512,
            max_seq_len=8192)

# fake inputs for Gato: two observations, each made of 20 image patches
# (a 4 x 5 grid in raster order), one continuous value and one discrete value
num_rows, num_cols = 4, 5
num_patches = num_rows * num_cols  # 20
tokens_per_step = num_patches + 2  # 22

def fake_observation(continuous_value, discrete_value):
    patches = [torch.rand((1, 1, input_dim)) for _ in range(num_patches)]  # image patches
    continuous = torch.full((1, 1, input_dim), continuous_value)         # continuous value
    discrete = torch.full((1, 1, input_dim), discrete_value)             # discrete (actions, texts)
    return patches + [continuous, discrete]

input_ids = torch.cat(fake_observation(0.25, 624.0) + fake_observation(0.12, 295.0),
                      dim=1)  # shape: (1, 44, input_dim)

# embedding type of each token:
# 0 = image patch, 1 = continuous value, 2 = discrete value (actions, texts)
encoding = torch.tensor([([0] * num_patches + [1, 2]) * 2])

# normalized [from, to] row/column interval of each patch; 0 for non-image tokens
patch_index = torch.arange(num_patches)
row_index = (patch_index // num_cols).float()
col_index = (patch_index % num_cols).float()

def per_token(patch_values):
    # zeros for the continuous/discrete tokens, repeated for both observations, plus a batch dim
    return torch.cat([patch_values, torch.zeros(2)]).repeat(2).unsqueeze(0)

row_pos = (per_token(row_index / num_rows),        # pos_from
           per_token((row_index + 1) / num_rows))  # pos_to
col_pos = (per_token(col_index / num_cols),        # pos_from
           per_token((col_index + 1) / num_cols))  # pos_to

obs = (
    torch.arange(tokens_per_step).repeat(2).unsqueeze(0),  # obs token (local position)
    torch.tensor([([1] * (num_patches + 1) + [0]) * 2]),    # obs token masking (0 for action tokens)
)

hidden_states = gato((input_ids, (encoding, row_pos, col_pos), obs))
```
### Dataset and Model Architecture
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://user-images.githubusercontent.com/5837620/215323793-7f7bcfdb-d8be-40d3-8e58-a053511f95d5.png">
<img alt="gato dataset and model architecture" src="https://user-images.githubusercontent.com/5837620/215323795-3a433516-f5ca-4272-9999-3df87ae521ba.png">
</picture>
## Paper Reviews
### Full Episode Sequence
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://user-images.githubusercontent.com/5837620/175756389-31d183c9-054e-4829-93a6-df79781ca212.png">
<img alt="gato dataset architecture" src="https://user-images.githubusercontent.com/5837620/175756409-75605dbc-7756-4509-ba93-c0ad08eea309.png">
</picture>
### Architecture Variants
> Appendix C.1. Transformer Hyperparameters

In the paper, DeepMind evaluated Gato with three architecture variants: `1.18B`, `364M`, and `79M`.<br>
I have named them `large()`, `baseline()`, and `small()`, respectively, in `GatoConfig`.

| Hyperparameters         | Large (1.18B) | Baseline (364M) | Small (79M) |
|-------------------------|---------------|-----------------|-------------|
| Transformer blocks      | 24            | 12              | 8           |
| Attention heads         | 16            | 12              | 24          |
| Layer width             | 2048          | 1536            | 768         |
| Feedforward hidden size | 8192          | 6144            | 3072        |
| Key/value size          | 128           | 128             | 32          |
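The table can be mirrored in code to sanity-check the variants. The sketch below is a hypothetical container, not the `GatoConfig` API: the class name `TransformerHyperparameters` and the `VARIANTS` dict are my own, and only the field names are borrowed from the `Gato(...)` keyword arguments in the usage example. It checks two relationships that hold for every column of the table: attention heads times key/value size equals the layer width, and the feedforward hidden size is four times the layer width.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TransformerHyperparameters:
    """Hypothetical container for one column of the table (not the GatoConfig API)."""
    num_transformer_blocks: int
    num_attention_heads: int
    layer_width: int
    feedforward_hidden_size: int
    key_value_size: int


VARIANTS = {
    "large": TransformerHyperparameters(24, 16, 2048, 8192, 128),     # 1.18B
    "baseline": TransformerHyperparameters(12, 12, 1536, 6144, 128),  # 364M
    "small": TransformerHyperparameters(8, 24, 768, 3072, 32),        # 79M
}

for name, hp in VARIANTS.items():
    # heads x key/value size equals the layer width in every variant
    assert hp.num_attention_heads * hp.key_value_size == hp.layer_width, name
    # the feedforward hidden size is 4x the layer width in every variant
    assert hp.feedforward_hidden_size == 4 * hp.layer_width, name
```

The `small` entry matches the hyperparameters passed to `Gato(...)` in the usage example.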
### Residual Embedding
> Appendix C.2. Embedding Function

There is no mention of how many residual blocks should be stacked for the token embeddings.<br>
Therefore, I have left this configurable in `GatoConfig`.

Whatever the number of residual blocks, full pre-activation is key.

Each block consists of:
- Version 2 ResNet architecture (based on ResNet50V2)
- GroupNorm (instead of LayerNorm)
- GELU (instead of ReLU)
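A full pre-activation block applies normalization and activation *before* each convolution and adds an untouched identity skip. The sketch below is a minimal PyTorch illustration under my own assumptions, not this package's code: it uses a basic two-convolution block rather than the ResNet50V2 bottleneck, assumes 32 GroupNorm groups (as in `num_group_norm_groups=32`), and the class names `PreActResidualBlock` and `ResidualEmbedding` are hypothetical.

```python
import torch
from torch import nn


class PreActResidualBlock(nn.Module):
    """Full pre-activation block: (GroupNorm -> GELU -> Conv) x 2, plus identity skip."""

    def __init__(self, channels: int, num_groups: int = 32):
        super().__init__()
        self.norm1 = nn.GroupNorm(num_groups, channels)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups, channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # normalization and activation come before each convolution ("pre-activation")
        h = self.conv1(self.act(self.norm1(x)))
        h = self.conv2(self.act(self.norm2(h)))
        # the skip path adds the input unchanged
        return x + h


class ResidualEmbedding(nn.Module):
    """Hypothetical stack with a configurable number of pre-activation blocks."""

    def __init__(self, channels: int, num_blocks: int, num_groups: int = 32):
        super().__init__()
        self.blocks = nn.Sequential(
            *[PreActResidualBlock(channels, num_groups) for _ in range(num_blocks)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.blocks(x)


x = torch.randn(2, 64, 16, 16)  # (batch, channels, height, width)
y = ResidualEmbedding(64, num_blocks=2)(x)
```

Because every block preserves the input shape, the number of stacked blocks can be changed freely, which is what makes it a natural configuration parameter.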
### Position Encodings
> Appendix C.3. Position Encodings
#### Patch Position Encodings
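Like Google's [Vision Transformer (ViT)](https://github.com/google-research/vision_transformer), Gato splits input images into 16x16 patches in raster order.<br>
Unlike ViT, however, Gato encodes patch positions differently during training and evaluation.

For efficient computation in TensorFlow, I have used the following expressions.

$R$ and $C$ denote row-wise and column-wise quantities, and $F$ and $T$ denote the `from` and `to` bounds of each patch interval. In this example the image is split into a 4x5 grid of patches, and each normalized interval is quantized to 128 positions. $\mathsf{uniform}(n)$ draws an integer uniformly from $[0, n)$, and $1_R$ and $1_C$ are row vectors of ones with one entry per row and per column, respectively.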
$$
\begin{align}
v^R_F &= \begin{bmatrix}
0 & 32 & 64 & 96
\end{bmatrix} \\
v^R_T &= \begin{bmatrix}
32 & 64 & 96 & 128
\end{bmatrix} \\
v^C_F &= \begin{bmatrix}
0 & 26 & 51 & 77 & 102
\end{bmatrix} \\
v^C_T &= \begin{bmatrix}
26 & 51 & 77 & 102 & 128
\end{bmatrix} \\
\\
P_R &= \begin{cases}
\mathsf{if} \ \mathsf{training} & v^R_F + \mathsf{uniform}(v^R_T - v^R_F) \\
\mathsf{otherwise} & \mathsf{round}(\frac{v^R_F + v^R_T}{2})
\end{cases} \\
P_C &= \begin{cases}
\mathsf{if} \ \mathsf{training} & v^C_F + \mathsf{uniform}(v^C_T - v^C_F) \\
\mathsf{otherwise} & \mathsf{round}(\frac{v^C_F + v^C_T}{2})
\end{cases} \\
\\
E^R_P &= P_R^{\mathsf{T}} \cdot 1_C \\
E^C_P &= 1^{\mathsf{T}}_R \cdot P_C \\
\\
\therefore E &= E_I + E^R_P + E^C_P
\end{align}
$$
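The expressions above can be sketched in plain Python. This is a minimal illustration, not the package's implementation: the helper names `quantized_intervals`, `patch_positions`, and `position_grids` are hypothetical. As I understand the paper, the resulting quantized indices are then used to look up learned row and column position embeddings that are added to the patch embedding $E_I$; the sketch stops at the indices.

```python
import random


def quantized_intervals(num_patches, resolution=128):
    """Return (v_F, v_T): bounds of num_patches equal intervals over [0, resolution], rounded."""
    bounds = [round(resolution * i / num_patches) for i in range(num_patches + 1)]
    return bounds[:-1], bounds[1:]


def patch_positions(num_patches, training, resolution=128, rng=random):
    """P = v_F + uniform(v_T - v_F) while training, round((v_F + v_T) / 2) otherwise."""
    v_from, v_to = quantized_intervals(num_patches, resolution)
    if training:
        # rng.randrange(n) draws an integer uniformly from [0, n)
        return [f + rng.randrange(t - f) for f, t in zip(v_from, v_to)]
    # note: Python's round() rounds halves to even, e.g. round(38.5) == 38
    return [round((f + t) / 2) for f, t in zip(v_from, v_to)]


def position_grids(p_row, p_col):
    """E^R_P[i][j] = P_R[i] and E^C_P[i][j] = P_C[j] over the rows x cols patch grid."""
    e_row = [[p for _ in p_col] for p in p_row]
    e_col = [list(p_col) for _ in p_row]
    return e_row, e_col


p_row = patch_positions(4, training=False)  # [16, 48, 80, 112]
p_col = patch_positions(5, training=False)  # [13, 38, 64, 90, 115]
e_row, e_col = position_grids(p_row, p_col)
```

`quantized_intervals(4)` and `quantized_intervals(5)` reproduce $v^R_F, v^R_T$ and $v^C_F, v^C_T$ exactly. Passing a seeded `random.Random` as `rng` makes the training-time sampling reproducible.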
#### Local Observation Position Encodings
As defined in Appendix B, text tokens, image-patch tokens, and discrete and continuous values are all observation tokens.<br>
When Gato receives them, each token must be encoded with its own local position within the time step.
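A minimal sketch of how the local positions and the observation mask could be built, assuming the convention of the `obs` tuple in the usage example: positions restart at 0 in every time step, and the mask is 1 for observation tokens and 0 for action tokens. The helper `local_observation_positions` is hypothetical and not part of the package.

```python
def local_observation_positions(timesteps):
    """Hypothetical helper: build (positions, mask) for a sequence of time steps.

    timesteps: list of (num_observation_tokens, num_action_tokens) pairs.
    Positions restart at 0 in every time step; the mask is 1 for
    observation tokens and 0 for action tokens.
    """
    positions, mask = [], []
    for num_obs, num_act in timesteps:
        positions.extend(range(num_obs + num_act))
        mask.extend([1] * num_obs + [0] * num_act)
    return positions, mask


# two time steps, each with 20 image patches + 1 continuous value,
# followed by 1 discrete action token (as in the usage example)
positions, mask = local_observation_positions([(21, 1), (21, 1)])
```

Restarting the count in every time step is what makes the encoding *local*: the same slot in different time steps receives the same position index.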
## Contributing
[We welcome all contributions! Please submit a pull request or open an issue in the Agora Discord.](https://discord.gg/qUtxnK2NMf)
## License
Licensed under the [MIT license](/LICENSE).
## Roadmap
* Get a functional prototype
* Integrate ALiBi, multi-query attention, QK normalization, and other state-of-the-art techniques
* Integrate action tokens
Raw data
{
"_id": null,
"home_page": "https://github.com/kyegomez/GATO",
"name": "gato-torch",
"maintainer": "",
"docs_url": null,
"requires_python": ">=3.10,<4.0",
"maintainer_email": "",
"keywords": "deep learning,gato,tensorflow",
"author": "Kye Gomez",
"author_email": "kye@apac.ai",
"download_url": "https://files.pythonhosted.org/packages/3a/95/4a1dd1e9725359c61ec1fffcb609e7a4c832bebee3484865a1ffd64dab60/gato_torch-0.0.2.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": "MIT",
"summary": "Gato: A Generalist Agent",
"version": "0.0.2",
"project_urls": {
"Homepage": "https://github.com/kyegomez/GATO"
},
"split_keywords": [
"deep learning",
"gato",
"tensorflow"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "9507e60fa544bd1a80f3dc384c5ef558ecb82d3bf6ff321c7f85fa52cec41309",
"md5": "3d11f63076814356db8f013dbf5a8a1a",
"sha256": "f040acf3c689966c0ce1bb64b3e388db43a760e481765407d3fbfa88cd517f68"
},
"downloads": -1,
"filename": "gato_torch-0.0.2-py3-none-any.whl",
"has_sig": false,
"md5_digest": "3d11f63076814356db8f013dbf5a8a1a",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.10,<4.0",
"size": 8118,
"upload_time": "2023-08-25T22:31:13",
"upload_time_iso_8601": "2023-08-25T22:31:13.427014Z",
"url": "https://files.pythonhosted.org/packages/95/07/e60fa544bd1a80f3dc384c5ef558ecb82d3bf6ff321c7f85fa52cec41309/gato_torch-0.0.2-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "3a954a1dd1e9725359c61ec1fffcb609e7a4c832bebee3484865a1ffd64dab60",
"md5": "c909c2d9eea5e79d316513b713687731",
"sha256": "e29746f33ef7406934bd7e65e0dd916baa9b1ee0fd35169759631a910f64b317"
},
"downloads": -1,
"filename": "gato_torch-0.0.2.tar.gz",
"has_sig": false,
"md5_digest": "c909c2d9eea5e79d316513b713687731",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.10,<4.0",
"size": 10450,
"upload_time": "2023-08-25T22:31:14",
"upload_time_iso_8601": "2023-08-25T22:31:14.917223Z",
"url": "https://files.pythonhosted.org/packages/3a/95/4a1dd1e9725359c61ec1fffcb609e7a4c832bebee3484865a1ffd64dab60/gato_torch-0.0.2.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2023-08-25 22:31:14",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "kyegomez",
"github_project": "GATO",
"travis_ci": false,
"coveralls": false,
"github_actions": true,
"requirements": [],
"lcname": "gato-torch"
}