flax-gate-loop

- Name: flax-gate-loop
- Version: 1.0.3
- Summary: GateLoop Model
- Home page: https://github.com/tobiaskatsch/GateLoop
- Author: Tobias Katsch
- Requires Python: >=3.7
- License: Apache License, Version 2.0
- Upload time: 2024-02-09 05:25:35
# GateLoop

<div style="display: flex; justify-content: space-around; align-items: center;">
    <img src="assets/gate_loop.png" alt="GateLoop" width="300px"/>
    <img src="assets/eq.png" alt="Equation" width="400px"/>
</div>

> **GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling**\
> Tobias Katsch*\
> Paper: https://arxiv.org/abs/2311.01927

## About


GateLoop is a novel sequence model which generalizes linear recurrent models such as S4, S5, LRU and RetNet by employing data-controlled state transitions.
While it offers a low-cost, linear-complexity inference mode, GateLoop can also be trained extremely efficiently in parallel with logarithmic parallel-time complexity, making use of the highly optimized JAX associative scan implementation.
This repository implements a practical (real-valued) GateLoop model with sensible default choices for the input, hidden and gate activations, and provides a drop-in replacement for causal multi-head attention as well as a GateLoop-based language model architecture.
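
The parallel training mode rests on evaluating a data-controlled linear recurrence of the generic form `h_t = a_t * h_{t-1} + b_t` with an associative scan. A minimal illustration of that idea (not the package internals), using `jax.lax.associative_scan`:

```python
import jax
import jax.numpy as jnp

def parallel_linear_recurrence(a, b):
    """Compute h_t = a_t * h_{t-1} + b_t (with h_{-1} = 0) for all t in parallel.

    a, b: arrays of shape (sequence_length, d_h) holding the per-step
    data-controlled transitions a_t and inputs b_t.
    """
    def combine(left, right):
        # Compose the affine maps h -> a*h + b of two adjacent segments.
        a_l, b_l = left
        a_r, b_r = right
        return a_l * a_r, a_r * b_l + b_r

    _, h = jax.lax.associative_scan(combine, (a, b))
    return h

# Tiny check against the sequential recurrence.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.nn.sigmoid(jax.random.normal(key_a, (8, 4)))
b = jax.random.normal(key_b, (8, 4))
h, hs = jnp.zeros(4), []
for t in range(8):
    h = a[t] * h + b[t]
    hs.append(h)
assert jnp.allclose(parallel_linear_recurrence(a, b), jnp.stack(hs), atol=1e-5)
```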

## Installation

- `pip install flax-gate-loop`: The core GateLoop package.

Other requirements:
- JAX 0.4.20+
- FLAX 0.8.0+
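
To verify that the installed versions satisfy these requirements:

```python
import jax
import flax

print("jax", jax.__version__)    # expected >= 0.4.20
print("flax", flax.__version__)  # expected >= 0.8.0
```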

## Usage
We provide 2 main modules:
- ### [gate_loop.py](flax_gate_loop/gate_loop.py)
  A causal time-mixing sequence model that can be used as a drop-in replacement for causal multi-head attention.
  Usage:
  ```python
  import jax
  import jax.numpy as jnp
  from flax import linen as nn
  from flax_gate_loop import GateLoop
  
  batch_size, sequence_length, input_dim, d_h = 2, 64, 16, 32
  key = jax.random.PRNGKey(0)
  x = jax.random.normal(key, (batch_size, sequence_length, input_dim))
  
  model = GateLoop(
      d_h=d_h,
      input_activation=nn.tanh,
      hidden_activation=nn.tanh,
      gate_activation=nn.sigmoid,
      use_true_recurrence=False,
      use_tied_gates=True,
  )
  
  params = model.init(jax.random.PRNGKey(1), x)
  y = model.apply(params, x)
  assert y.shape == (batch_size, sequence_length, d_h)
  ```
  #### Two Stage Training
  - **Associative Recurrent Mode** (`use_true_recurrence=False`): Extremely efficient training through a parallel scan. This disables the recurrent weights, allowing for much faster training than Transformers, GRUs & LSTMs.
  - **True Recurrent Mode** (`use_true_recurrence=True`): Can be used to train a more expressive model starting from a linear recurrent checkpoint. This variant introduces additional parameters such that the gates also depend on previous hidden states, similar to GRUs & LSTMs. Due to its truly recurrent nature, this mode cannot be parallelized and is therefore less efficient. We recommend it for finetuning from a linear recurrent checkpoint; see the sketch below.
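
  A minimal sketch of this two-stage workflow (an illustration, not a utility provided by the package), assuming the two modes share parameter names; any parameter that only exists in the true recurrent mode keeps its fresh initialization:

  ```python
  from flax import traverse_util

  # Stage 2 model: same configuration as above, but truly recurrent.
  recurrent_model = GateLoop(
      d_h=d_h,
      input_activation=nn.tanh,
      hidden_activation=nn.tanh,
      gate_activation=nn.sigmoid,
      use_true_recurrence=True,
      use_tied_gates=True,
  )
  fresh_params = recurrent_model.init(jax.random.PRNGKey(3), x)

  # Copy every parameter that also exists in the stage-1 (associative-mode) checkpoint.
  flat_fresh = traverse_util.flatten_dict(fresh_params)
  flat_ckpt = traverse_util.flatten_dict(params)  # `params` from the example above
  merged = traverse_util.unflatten_dict(
      {k: flat_ckpt.get(k, v) for k, v in flat_fresh.items()}
  )

  y = recurrent_model.apply(merged, x)
  ```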

  #### Gate Tying
  - **Disjoint Input & Forget Gate** (`use_tied_gates=False`): Applies separate projections for the input and forget gates.
  - **Tied Input & Forget Gate** (`use_tied_gates=True`): Ties the input and forget gates through the relation `forget_gate = 1 - input_gate`; see the sketch below.
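
  A conceptual illustration of the tied-gate update (not the library internals), where the forget gate is derived from the input gate:

  ```python
  import jax.numpy as jnp

  def tied_gate_step(h_prev, x_t, input_gate):
      # Tied gates: whatever fraction is written is removed from the previous state.
      forget_gate = 1.0 - input_gate
      return forget_gate * h_prev + input_gate * x_t

  h = tied_gate_step(jnp.zeros(4), jnp.ones(4), jnp.full(4, 0.25))  # -> 0.25 everywhere
  ```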


- ### [gate_loop_lm.py](flax_gate_loop/language_models/gate_loop_lm.py)
  A GateLoop-based language model.
  ```python
  import jax
  import jax.numpy as jnp
  from flax import linen as nn
  from flax_gate_loop import GateLoopLM
  
  # Model parameters
  n_layer = 4
  d_model = 256
  d_channel_mixing = 64
  eps = 1e-6
  channel_mixing_dropout = 0.1
  time_mixing_dropout = 0.1
  input_vocab_size = 10000
  output_vocab_size = 10000
  max_seq_length = 512
  embedding_dropout = 0.1
  use_word_embedding = True
  positional_encoding_mode = 'none'  # 'none', 'learned', 'sinusoidal'
  d_h = 256
  
  model = GateLoopLM(
    n_layer=n_layer,
    d_model=d_model,
    d_channel_mixing=d_channel_mixing,
    eps=eps,
    channel_mixing_dropout=channel_mixing_dropout,
    time_mixing_dropout=time_mixing_dropout,
    input_vocab_size=input_vocab_size,
    output_vocab_size=output_vocab_size,
    max_seq_length=max_seq_length,
    embedding_dropout=embedding_dropout,
    use_word_embedding=use_word_embedding,
    positional_encoding_mode=positional_encoding_mode,
    d_h=d_h,
    input_activation=nn.tanh,
    hidden_activation=nn.tanh,
    gate_activation=nn.sigmoid,
    use_true_recurrence=False,
    use_tied_gates=True,
  )
  
  # Sample input
  batch_size = 32
  x = jax.random.randint(jax.random.PRNGKey(0), (batch_size, max_seq_length), 0, input_vocab_size)
  
  # Initialize and apply model
  params = model.init(jax.random.PRNGKey(2), x, training=False)
  y = model.apply(params, x, training=True, rngs={'dropout': jax.random.PRNGKey(0)})
  assert y.shape == (batch_size, max_seq_length, output_vocab_size)
  ```
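
  For training, a next-token cross-entropy loss can be attached in the usual Flax/JAX way. A minimal sketch, assuming `optax` is installed (it is not a dependency of this package):

  ```python
  import optax

  def loss_fn(params, batch, dropout_rng):
      logits = model.apply(params, batch, training=True,
                           rngs={'dropout': dropout_rng})
      # Predict token t+1 from tokens up to t.
      return optax.softmax_cross_entropy_with_integer_labels(
          logits[:, :-1, :], batch[:, 1:]
      ).mean()

  loss, grads = jax.value_and_grad(loss_fn)(params, x, jax.random.PRNGKey(3))
  ```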

## Citation

If you use this codebase, please cite:
```bibtex
@misc{katsch2024gateloop,
      title={GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling}, 
      author={Tobias Katsch},
      year={2024},
      eprint={2311.01927},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```

            
