givt-pytorch


- Name: givt-pytorch
- Version: 0.0.3
- Home page: https://github.com/elyxlz/givt-pytorch
- Summary: A partial implementation of Generative Infinite Vocabulary Transformer (GIVT) from Google DeepMind, in PyTorch.
- Upload time: 2024-01-08 01:25:15
- Author: Elio Pascarelli
- Requires Python: >=3.9
- License: MIT

## GIVT-PyTorch
A partial implementation of the [Generative Infinite-Vocabulary Transformer (GIVT)](https://arxiv.org/abs/2312.02116) from Google DeepMind, in PyTorch.

This repo implements only the causal version of GIVT. It omits the k-mixture predictions and the full covariance matrix, since for most purposes they did not yield significantly better results.
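
To make the simplification concrete, here is a minimal sketch of the loss that remains once mixtures and the full covariance are dropped: the negative log-likelihood of a single Gaussian with independent dimensions. The function name `diagonal_gaussian_nll` and the mean / log-std parameterisation are assumptions for illustration, and may differ from how this repo parameterises its output head.

```python
import math

import torch


def diagonal_gaussian_nll(target, mean, log_std):
    """Mean NLL of `target` under N(mean, exp(log_std)^2), one Gaussian per dimension.

    Hypothetical helper for illustration; not part of givt_pytorch.
    """
    var = torch.exp(2.0 * log_std)
    # -log N(x; mu, sigma^2) = 0.5 * ((x - mu)^2 / sigma^2 + 2 * log(sigma) + log(2 * pi))
    nll = 0.5 * ((target - mean) ** 2 / var + 2.0 * log_std + math.log(2.0 * math.pi))
    return nll.mean()


torch.manual_seed(0)
target = torch.randn(4, 500, 32)    # (batch_size, seq_len, latent_size)
mean = torch.zeros(4, 500, 32)      # stand-in for a predicted mean
log_std = torch.zeros(4, 500, 32)   # stand-in for a predicted log standard deviation
loss = diagonal_gaussian_nll(target, mean, log_std)
```

With a diagonal covariance each latent dimension contributes an independent term, so no matrix inverse or determinant is needed.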

The decoder transformer implementation is also modernized, adopting a Llama-style architecture with gated MLPs, SiLU, RMSNorm, and RoPE.
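
For readers unfamiliar with these components, the following is a rough sketch of two of them, RMSNorm and a SiLU-gated MLP, as they are commonly written for Llama-style models. The class names, layer names, and sizes are assumptions for illustration and are not taken from this repo's source; attention and RoPE are left out to keep the example short.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Scale activations by their root mean square (no mean subtraction, no bias)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight


class GatedMLP(nn.Module):
    """Feed-forward block where a SiLU-activated gate multiplies a linear branch."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


torch.manual_seed(0)
x = torch.randn(2, 10, 64)  # (batch_size, seq_len, dim), illustrative sizes
norm = RMSNorm(64)
mlp = GatedMLP(64, 256)
block_out = x + mlp(norm(x))  # pre-norm residual connection
```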

## Install
```sh
# for inference (run from a clone of the repository)
pip install .

# for training/development (editable install with the train extras)
pip install -e '.[train]'
```


## Usage
```py
import torch
from givt_pytorch import GIVT

# load a pretrained checkpoint
model = GIVT.from_pretrained('elyxlz/givt-test')

# forward pass on a batch of VAE latents
latents = torch.randn((4, 500, 32))  # (batch_size, seq_len, latent_size)
loss = model.forward(latents).loss  # NLL loss

# generation from an unbatched prompt (batched inference is not implemented)
prompt = torch.randn((50, 32))  # (prompt_len, latent_size)
generated = model.generate(
    prompt=prompt,
    max_len=500,
    cfg_scale=0.5,
    temperature=0.95,
)  # (500, 32)
```
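
As a rough intuition for the `temperature` argument, one common way to apply a temperature to a Gaussian output is to scale its standard deviation before sampling. The sketch below shows that idea only; the helper `sample_next_latent` and its mean / log-std inputs are hypothetical, and this repo's `generate` may apply temperature differently.

```python
import torch


def sample_next_latent(mean, log_std, temperature=1.0, generator=None):
    """Draw one latent from N(mean, (temperature * std)^2).

    Hypothetical helper for illustration; not part of givt_pytorch.
    """
    std = torch.exp(log_std) * temperature
    noise = torch.randn(mean.shape, generator=generator)
    return mean + std * noise


gen = torch.Generator().manual_seed(0)
pred_mean = torch.zeros(32)      # stand-in for a predicted mean, (latent_size,)
pred_log_std = torch.zeros(32)   # stand-in for a predicted log standard deviation
sample = sample_next_latent(pred_mean, pred_log_std, temperature=0.95, generator=gen)
```

Under this interpretation, a temperature below 1 concentrates samples around the predicted mean, and a temperature of 0 returns the mean itself.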

## Training

Define a config file in `configs/`, such as this one:
```py
from givt_pytorch import (
    GIVT,
    GIVTConfig,
    DummyDataset,
    Trainer,
    TrainConfig
)

model = GIVT(GIVTConfig())
dataset = DummyDataset()

trainer = Trainer(
    model=model,
    dataset=dataset,
    train_config=TrainConfig()
)
```
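
To train on your own data, `DummyDataset` would presumably be swapped for a dataset of real latents. The sketch below assumes, without confirmation from the source, that the trainer accepts a standard `torch.utils.data.Dataset` whose items are latent tensors of shape `(seq_len, latent_size)`; the class `RandomLatentDataset` is hypothetical and only generates random data as a placeholder.

```python
import torch
from torch.utils.data import Dataset


class RandomLatentDataset(Dataset):
    """Placeholder dataset of random latents; replace with real VAE latents.

    Hypothetical class for illustration; not part of givt_pytorch.
    """

    def __init__(self, num_items=16, seq_len=500, size=32, seed=0):
        generator = torch.Generator().manual_seed(seed)
        self.latents = torch.randn(num_items, seq_len, size, generator=generator)

    def __len__(self):
        return self.latents.shape[0]

    def __getitem__(self, idx):
        return self.latents[idx]  # (seq_len, latent_size)


dataset = RandomLatentDataset()
```

Check the return format of `DummyDataset` in the repository before relying on this shape convention.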
Create an Accelerate config:
```sh
accelerate config
```

Then launch training, passing the name of your config file:
```sh
accelerate launch train.py {config_name}
```

## TODO
- [ ] Test out with latents from an audio vae
- [ ] Add CFG with rejection sampling

## References
```bibtex
@misc{litgpt2024,
    title   = {lit-gpt on GitHub},
    url     = {https://github.com/Lightning-AI/lit-gpt},
    year    = {2024}
}

@misc{tschannen2023givt,
    title   = {GIVT: Generative Infinite-Vocabulary Transformers},
    author  = {Michael Tschannen and Cian Eastwood and Fabian Mentzer},
    year    = {2023},
    eprint  = {2312.02116},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```

            
