| Field | Value |
| --- | --- |
| Name | givt-pytorch |
| Version | 0.0.3 |
| Summary | A partial implementation of Generative Infinite Vocabulary Transformer (GIVT) from Google Deepmind, in PyTorch. |
| Home page | https://github.com/elyxlz/givt-pytorch |
| Author | Elio Pascarelli |
| License | MIT |
| Requires Python | >=3.9 |
| Upload time | 2024-01-08 01:25:15 |
## GIVT-PyTorch
A partial implementation of [Generative Infinite Vocabulary Transformer (GIVT)](https://arxiv.org/abs/2312.02116) from Google Deepmind, in PyTorch.
This repo implements only the causal version of GIVT, and drops both the k-mixture predictions and the full covariance matrix, since for most purposes neither yielded significantly better results.
The decoder transformer implementation is also modernized, adopting a Llama-style architecture with gated MLPs, SiLU, RMSNorm, and RoPE.
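Concretely, instead of a softmax over a discrete vocabulary, the output head parameterizes a single diagonal Gaussian over the next continuous latent and is trained with its negative log-likelihood. A minimal sketch of that idea (the class and method names here are illustrative, not this repo's actual API):
```py
import torch
import torch.nn as nn

class GaussianHead(nn.Module):
    """Maps decoder hidden states to a diagonal Gaussian over the next latent."""

    def __init__(self, dim: int, latent_size: int):
        super().__init__()
        self.proj = nn.Linear(dim, latent_size * 2)  # predicts mean and log-variance

    def forward(self, hidden: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        mu, log_var = self.proj(hidden).chunk(2, dim=-1)
        # Gaussian NLL per dimension, up to the constant 0.5 * log(2 * pi)
        nll = 0.5 * (log_var + (target - mu).pow(2) / log_var.exp())
        return nll.mean()
```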
## Install
```sh
# for inference
pip install .
# for training/development
pip install -e '.[train]'
```
## Usage
```py
import torch
from givt_pytorch import GIVT
# load pretrained checkpoint
model = GIVT.from_pretrained('elyxlz/givt-test')
latents = torch.randn((4, 500, 32))  # VAE latents, shape (batch, seq_len, latent_size)
loss = model(latents).loss  # negative log-likelihood loss
prompt = torch.randn((50, 32))  # single unbatched prompt; batched inference is not implemented
generated = model.generate(
prompt=prompt,
max_len=500,
cfg_scale=0.5,
temperature=0.95,
) # (500, 32)
```
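The `cfg_scale` and `temperature` arguments control how the next latent is sampled from the predicted Gaussian. A rough sketch of one decoding step under common conventions (the function and the exact guidance formula are assumptions, not this repo's code):
```py
import torch

def sample_next(mu_cond, mu_uncond, var, cfg_scale=0.5, temperature=0.95):
    """Hypothetical GIVT decoding step: push the conditional mean away from
    the unconditional one, then sample with a temperature-scaled std-dev."""
    mu = mu_cond + cfg_scale * (mu_cond - mu_uncond)  # classifier-free guidance on the mean
    std = var.sqrt() * temperature                    # temperature widens or narrows the spread
    return mu + std * torch.randn_like(mu)
```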
## Training
Define a config file in `configs/`, such as this one:
```py
from givt_pytorch import (
GIVT,
GIVTConfig,
DummyDataset,
Trainer,
TrainConfig
)
model = GIVT(GIVTConfig())
dataset = DummyDataset()
trainer = Trainer(
model=model,
dataset=dataset,
train_config=TrainConfig()
)
```
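To train on real data, you would swap `DummyDataset` for your own dataset yielding latent tensors. A hypothetical sketch, assuming latents were precomputed and saved as one tensor (the class and file layout are illustrative):
```py
import torch
from torch.utils.data import Dataset

class LatentDataset(Dataset):
    """Hypothetical dataset of precomputed VAE latents stored in a single file."""

    def __init__(self, path: str):
        self.latents = torch.load(path)  # shape (num_examples, seq_len, latent_size)

    def __len__(self):
        return len(self.latents)

    def __getitem__(self, idx):
        return self.latents[idx]
```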
Create an accelerate config.
```sh
accelerate config
```
And then run the training.
```sh
accelerate launch train.py {config_name}
```
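Here `{config_name}` refers to the module in `configs/` without the `.py` extension. A plausible sketch of how `train.py` might resolve it, assuming the config module exposes the `trainer` built above (this is a guess at the mechanism, not the repo's actual script):
```py
import importlib
import sys

# Hypothetical: import configs/<config_name>.py and run its Trainer.
config = importlib.import_module(f"configs.{sys.argv[1]}")
config.trainer.train()  # assumed entry point on the Trainer instance
```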
## TODO
- [ ] Test out with latents from an audio VAE
- [ ] Add CFG with rejection sampling
## References
```bibtex
@misc{litgpt2024,
    title = {lit-gpt on GitHub},
    url = {https://github.com/Lightning-AI/lit-gpt},
    year = {2024}
}

@misc{tschannen2023givt,
    title = {GIVT: Generative Infinite-Vocabulary Transformers},
    author = {Michael Tschannen and Cian Eastwood and Fabian Mentzer},
    year = {2023},
    eprint = {2312.02116},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```