# GateLoop
<div style="display: flex; justify-content: space-around; align-items: center;">
<img src="assets/gate_loop.png" alt="GateLoop" width="300px"/>
<img src="assets/eq.png" alt="Equation" width="400px"/>
</div>
> **GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling**\
> Tobias Katsch*\
> Paper: https://arxiv.org/abs/2311.01927
## About
GateLoop is a novel sequence model that generalizes linear recurrent models such as S4, S5, LRU, and RetNet by employing data-controlled state transitions.
While it offers a low-cost, linear-complexity inference mode, GateLoop can also be trained extremely efficiently in parallel with logarithmic depth complexity, making use of JAX's highly optimized associative scan implementation.
This repository implements a practical (real-valued) GateLoop model with sensible defaults for the input, hidden, and gate activations, and provides both a drop-in replacement for causal multi-head attention and a GateLoop-based language model architecture.
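For intuition, the core computation is a data-controlled linear recurrence `h_t = a_t * h_{t-1} + x_t`, whose pairwise composition is associative and can therefore be evaluated with `jax.lax.associative_scan`. Below is a minimal sketch of this idea, not the library's actual implementation:
```
import jax
import jax.numpy as jnp

def linear_recurrence_scan(a, x):
    # h_t = a_t * h_{t-1} + x_t, evaluated for all t in parallel.
    # Pairs (a, x) compose associatively:
    #   (a1, x1) then (a2, x2)  ==  (a1 * a2, a2 * x1 + x2)
    # so associative_scan runs in O(log T) depth instead of O(T).
    def combine(left, right):
        a_l, x_l = left
        a_r, x_r = right
        return a_l * a_r, a_r * x_l + x_r
    _, h = jax.lax.associative_scan(combine, (a, x))
    return h

T, d = 64, 8
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
a = jax.nn.sigmoid(jax.random.normal(k1, (T, d)))  # data-controlled transitions
x = jax.random.normal(k2, (T, d))
h = linear_recurrence_scan(a, x)

# Matches the sequential recurrence.
h_ref = jnp.zeros(d)
for t in range(T):
    h_ref = a[t] * h_ref + x[t]
assert jnp.allclose(h[-1], h_ref, atol=1e-5)
```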
## Installation
- `pip install flax-gate-loop`: The core GateLoop package.
Other requirements:
- JAX 0.4.20+
- FLAX 0.8.0+
## Usage
We provide two main modules:
- ### [gate_loop.py](flax_gate_loop/gate_loop.py)
A causal time-mixing sequence model that can be used as a drop-in replacement for causal multi-head attention.
Usage:
```
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax_gate_loop import GateLoop
batch_size, sequence_length, input_dim, d_h = 2, 64, 16, 32
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (batch_size, sequence_length, input_dim))
model = GateLoop(
    d_h=d_h,
    input_activation=nn.tanh,
    hidden_activation=nn.tanh,
    gate_activation=nn.sigmoid,
    use_true_recurrence=False,
    use_tied_gates=True,
)
params = model.init(jax.random.PRNGKey(1), x)
y = model.apply(params, x)
assert y.shape == (batch_size, sequence_length, d_h)
```
#### Two-Stage Training
- **Associative Recurrent Mode** (`use_true_recurrence=False`): Extremely efficient training through a parallel scan. This disables the recurrent weights, allowing for much faster training than Transformers, GRUs, and LSTMs.
- **True Recurrent Mode** (`use_true_recurrence=True`): Can be used to train a more expressive model starting from a linear recurrent checkpoint. This variant introduces additional parameters so that the gates also depend on previous hidden states, similarly to GRUs and LSTMs. Due to its truly recurrent nature, this mode cannot be parallelized and is therefore less efficient. We recommend it for fine-tuning from a linear recurrent checkpoint; see the warm-start sketch below.
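The snippet below sketches how such a two-stage setup might look: pretrain in the associative mode, then re-initialize a true recurrent model and copy over the shared parameters. The `warm_start` helper is hypothetical and assumes the shared parameters keep the same names and shapes across both modes:
```
import jax
from flax import linen as nn
from flax_gate_loop import GateLoop

x = jax.random.normal(jax.random.PRNGKey(0), (2, 64, 16))
kwargs = dict(d_h=32, input_activation=nn.tanh, hidden_activation=nn.tanh,
              gate_activation=nn.sigmoid, use_tied_gates=True)

# Stage 1: pretrain cheaply in the parallel associative mode.
stage1_params = GateLoop(use_true_recurrence=False, **kwargs).init(jax.random.PRNGKey(1), x)
# ... optimize stage1_params ...

# Stage 2: enable the recurrent weights and warm-start from stage 1.
stage2 = GateLoop(use_true_recurrence=True, **kwargs)
stage2_params = stage2.init(jax.random.PRNGKey(2), x)

def warm_start(fresh, pretrained):
    # Copy every parameter present in both trees; parameters unique to
    # the true recurrent mode keep their fresh initialization.
    if not isinstance(fresh, dict):
        return pretrained
    return {k: warm_start(v, pretrained[k]) if k in pretrained else v
            for k, v in fresh.items()}

stage2_params = warm_start(stage2_params, stage1_params)
```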
#### Gate Tying
- **Disjoint Input & Forget Gates** (`use_tied_gates=False`): Applies separate projections for the input and forget gates.
- **Tied Input & Forget Gates** (`use_tied_gates=True`): Ties the input and forget gates through the relation `forget_gate = 1 - input_gate`.
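For intuition, tying the gates turns each channel update into a data-controlled exponential moving average; schematically (not the library's internals):
```
def tied_gate_step(h_prev, input_gate, v):
    # With forget_gate = 1 - input_gate, the new state is a convex
    # combination of the previous state and the incoming value.
    return (1.0 - input_gate) * h_prev + input_gate * v
```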
- ### [gate_loop_lm.py](flax_gate_loop/language_models/gate_loop_lm.py)
A GateLoop-based language model.
```
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax_gate_loop import GateLoopLM
# Model parameters
n_layer = 4
d_model = 256
d_channel_mixing = 64
eps = 1e-6
channel_mixing_dropout = 0.1
time_mixing_dropout = 0.1
input_vocab_size = 10000
output_vocab_size = 10000
max_seq_length = 512
embedding_dropout = 0.1
use_word_embedding = True
positional_encoding_mode = 'none' # 'none', 'learned', 'sinusoidal'
d_h = 256
model = GateLoopLM(
    n_layer=n_layer,
    d_model=d_model,
    d_channel_mixing=d_channel_mixing,
    eps=eps,
    channel_mixing_dropout=channel_mixing_dropout,
    time_mixing_dropout=time_mixing_dropout,
    input_vocab_size=input_vocab_size,
    output_vocab_size=output_vocab_size,
    max_seq_length=max_seq_length,
    embedding_dropout=embedding_dropout,
    use_word_embedding=use_word_embedding,
    positional_encoding_mode=positional_encoding_mode,
    d_h=d_h,
    input_activation=nn.tanh,
    hidden_activation=nn.tanh,
    gate_activation=nn.sigmoid,
    use_true_recurrence=False,
    use_tied_gates=True,
)
# Sample input
batch_size = 32
x = jax.random.randint(jax.random.PRNGKey(0), (batch_size, max_seq_length), 0, input_vocab_size)
# Initialize and apply model
params = model.init(jax.random.PRNGKey(2), x, training=False)
y = model.apply(params, x, training=True, rngs={'dropout': jax.random.PRNGKey(0)})
assert y.shape == (batch_size, max_seq_length, output_vocab_size)
```
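To train the language model, a next-token cross-entropy loss can be placed on top of the logits. Below is a minimal sketch using `optax` (not a dependency of this package), continuing from the example above:
```
import optax

# Shift by one position: predict token t+1 from tokens up to t.
logits, targets = y[:, :-1, :], x[:, 1:]
loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
```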
## Citation
If you use this codebase, please cite:
```
@misc{katsch2024gateloop,
title={GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling},
author={Tobias Katsch},
year={2024},
eprint={2311.01927},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```