[![Multi-Modality](agorabanner.png)](https://discord.com/servers/agora-999382051935506503)
# Jax Transformer
[![Join our Discord](https://img.shields.io/badge/Discord-Join%20our%20server-5865F2?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/agora-999382051935506503) [![Subscribe on YouTube](https://img.shields.io/badge/YouTube-Subscribe-red?style=for-the-badge&logo=youtube&logoColor=white)](https://www.youtube.com/@kyegomez3242) [![Connect on LinkedIn](https://img.shields.io/badge/LinkedIn-Connect-blue?style=for-the-badge&logo=linkedin&logoColor=white)](https://www.linkedin.com/in/kye-g-38759a207/) [![Follow on X.com](https://img.shields.io/badge/X.com-Follow-1DA1F2?style=for-the-badge&logo=x&logoColor=white)](https://x.com/kyegomezb)
This repository demonstrates how to build a **Decoder-Only Transformer** with **Multi-Query Attention** in **JAX**. Multi-Query Attention is an efficient variant of standard multi-head attention in which every head keeps its own query projection while all heads share a single key projection and a single value projection.
## Table of Contents
- [Overview](#overview)
- [Key Concepts](#key-concepts)
- [Installation](#installation)
- [Usage](#usage)
- [Code Walkthrough](#code-walkthrough)
  - [Multi-Query Attention](#multi-query-attention)
  - [Feed-Forward Layer](#feed-forward-layer)
  - [Transformer Decoder Block](#transformer-decoder-block)
  - [Causal Masking](#causal-masking)
- [Running the Transformer Decoder](#running-the-transformer-decoder)
- [Contributing](#contributing)
- [License](#license)
## Overview
This project is a tutorial for building Transformer models from scratch in **JAX**, with a specific focus on implementing **Decoder-Only Transformers** using **Multi-Query Attention**. Transformers are widely used for NLP tasks such as language modeling and text generation. Multi-Query Attention (MQA) is a variant of multi-head attention that shares a single key projection and a single value projection across all heads. This shrinks the key/value projection weights and, during autoregressive decoding, reduces the key-value cache by a factor equal to the number of heads.
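To make the memory saving concrete, the sketch below counts key-value cache bytes for standard multi-head attention versus MQA. The helper `kv_cache_bytes` and the 2-byte (fp16) element size are illustrative assumptions, not part of this package; the sizes are chosen to mirror the Usage example (8 heads and `dim = 64`, so a head dimension of 8, with `depth = 6`).

```python
def kv_cache_bytes(batch, seq_len, num_layers, num_kv_heads, head_dim, bytes_per_elem=2):
    """Bytes needed to cache keys and values for every layer (hypothetical helper)."""
    # Factor of 2: one cached tensor for keys, one for values.
    return 2 * batch * seq_len * num_layers * num_kv_heads * head_dim * bytes_per_elem


heads, head_dim = 8, 8  # dim = heads * head_dim = 64
# Multi-head attention caches one key/value head per query head...
mha = kv_cache_bytes(batch=2, seq_len=10, num_layers=6, num_kv_heads=heads, head_dim=head_dim)
# ...while multi-query attention caches a single shared key/value head.
mqa = kv_cache_bytes(batch=2, seq_len=10, num_layers=6, num_kv_heads=1, head_dim=head_dim)
print(mha, mqa, mha // mqa)  # 30720 3840 8
```

Only the `num_kv_heads` term differs between the two calls, so the ratio is always the number of heads.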
## Key Concepts
- **Multi-Query Attention**: Shares a single key and value head across all query heads, reducing the memory needed for keys and values (notably the key-value cache during decoding) compared to standard multi-head attention.
- **Transformer Decoder Block**: The core component of decoder models, consisting of multi-query attention, a feed-forward network, residual connections, and layer normalization.
- **Causal Masking**: Ensures that each position in the sequence can only attend to itself and previous positions to prevent future token leakage during training.
## Installation
```bash
pip3 install -U jax-transformer
```
### Requirements
- **Python**: 3.10 or later (the package declares `requires_python = "<4.0,>=3.10"`).
- **JAX**: A library for high-performance machine learning research. To use a GPU (optional), install a GPU-enabled build by following the instructions on the [JAX GitHub page](https://github.com/google/jax).
## Usage
After installing the dependencies, you can run the model on random input data to see how the transformer decoder works:
```python
import jax
from jax_transformer.main import transformer_decoder, causal_mask
# Example usage
batch_size = 2
seq_len = 10
dim = 64
heads = 8
d_ff = 256
depth = 6
# Random input embeddings of shape (batch_size, seq_len, dim)
x = jax.random.normal(
jax.random.PRNGKey(0), (batch_size, seq_len, dim)
)
rng = jax.random.PRNGKey(42)
# Generate causal mask
mask = causal_mask(seq_len)
# Run through transformer decoder
out = transformer_decoder(
x=x,
mask=mask,
depth=depth,
heads=heads,
dim=dim,
d_ff=d_ff,
dropout_rate=0.1,
rng=rng,
)
print(out.shape) # Should be (batch_size, seq_len, dim)
```
## Code Walkthrough
This section outlines the key components of the model and their function signatures.
### Multi-Query Attention
The **Multi-Query Attention** mechanism replaces standard multi-head attention: every head keeps its own query projection, but all heads share one key projection and one value projection. With `h` heads, the model stores one set of keys and values per token instead of `h`, which reduces memory traffic and the size of the key-value cache during autoregressive decoding.
```python
def multi_query_attention(query, key, value, mask):
...
```
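The stub above omits the body. Below is a minimal sketch of multi-query attention, assuming a functional parameter dictionary and a boolean mask in which `True` means "may attend". The helper `init_mqa_params` and the `multi_query_attention(params, x, mask, heads)` signature are hypothetical and differ from the package's `multi_query_attention(query, key, value, mask)`; consult the package source for its actual implementation.

```python
import jax
import jax.numpy as jnp


def init_mqa_params(rng, dim, heads):
    """Hypothetical layout: per-head query projections, one shared key/value head."""
    head_dim = dim // heads
    kq, kk, kv, ko = jax.random.split(rng, 4)
    scale = 1.0 / jnp.sqrt(dim)
    return {
        "wq": jax.random.normal(kq, (dim, heads * head_dim)) * scale,  # h query heads
        "wk": jax.random.normal(kk, (dim, head_dim)) * scale,  # 1 shared key head
        "wv": jax.random.normal(kv, (dim, head_dim)) * scale,  # 1 shared value head
        "wo": jax.random.normal(ko, (heads * head_dim, dim)) * scale,
    }


def multi_query_attention(params, x, mask, heads):
    """x: (batch, seq, dim); mask: (seq, seq) bool, True where attention is allowed."""
    b, t, d = x.shape
    head_dim = d // heads
    q = (x @ params["wq"]).reshape(b, t, heads, head_dim)  # (b, t, h, hd)
    k = x @ params["wk"]  # (b, t, hd), shared by every head
    v = x @ params["wv"]  # (b, t, hd), shared by every head
    scores = jnp.einsum("bqhd,bkd->bhqk", q, k) / jnp.sqrt(head_dim)
    scores = jnp.where(mask, scores, -1e9)  # (t, t) mask broadcasts over b and h
    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("bhqk,bkd->bqhd", weights, v).reshape(b, t, heads * head_dim)
    return out @ params["wo"]


x = jax.random.normal(jax.random.PRNGKey(0), (2, 10, 64))
mask = jnp.tril(jnp.ones((10, 10), dtype=bool))
params = init_mqa_params(jax.random.PRNGKey(1), dim=64, heads=8)
out = multi_query_attention(params, x, mask, heads=8)
print(out.shape)  # (2, 10, 64)
```

Relative to multi-head attention, only `wk` and `wv` change: they produce one `head_dim`-wide key and value per token, which the einsum broadcasts across all query heads.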
### Feed-Forward Layer
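After the attention sub-layer, each position passes through a two-layer feed-forward network with a ReLU activation between the layers. The first layer expands each position's representation to `d_ff` features and the second projects it back to the model dimension, adding per-position nonlinear capacity.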
```python
def feed_forward(x, d_ff):
...
```
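A minimal sketch of the position-wise feed-forward layer, assuming the common expand-then-project shape (`dim -> d_ff -> dim`) with biases. The helper `init_ffn_params` and the `feed_forward(params, x)` signature are hypothetical; the package's own function takes `(x, d_ff)` and may manage its parameters differently.

```python
import jax
import jax.numpy as jnp


def init_ffn_params(rng, dim, d_ff):
    """Hypothetical parameters for a dim -> d_ff -> dim MLP."""
    k1, k2 = jax.random.split(rng)
    return {
        "w1": jax.random.normal(k1, (dim, d_ff)) / jnp.sqrt(dim),
        "b1": jnp.zeros((d_ff,)),
        "w2": jax.random.normal(k2, (d_ff, dim)) / jnp.sqrt(d_ff),
        "b2": jnp.zeros((dim,)),
    }


def feed_forward(params, x):
    # Expand to d_ff features, apply ReLU, then project back to dim.
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


params = init_ffn_params(jax.random.PRNGKey(0), dim=64, d_ff=256)
x = jax.random.normal(jax.random.PRNGKey(1), (2, 10, 64))
y = feed_forward(params, x)
print(y.shape)  # (2, 10, 64)
```

Because the layer acts on each position independently, applying it to a single slice such as `x[:, 3:4]` gives the same result as slicing the full output.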
### Transformer Decoder Block
The **Transformer Decoder Block** combines the multi-query attention mechanism with the feed-forward network and adds **residual connections** and **layer normalization** to stabilize training. It processes sequences causally: each token can attend only to itself and earlier tokens, which is essential for autoregressive models such as language models.
```python
def transformer_decoder_block(x, key, value, mask, num_heads, d_model, d_ff):
...
```
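The sketch below wires multi-query attention and a feed-forward layer into one decoder block with residual connections and layer normalization, then stacks `depth` blocks. It assumes a pre-norm arrangement (normalize before each sub-layer), omits dropout and the learnable layer-norm scale and shift, and uses hypothetical helpers (`init_block_params`, `decoder_block`, `decoder`); the package's `transformer_decoder_block` may order these operations differently.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    # Normalize over features; learnable scale/shift omitted for brevity.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def init_block_params(rng, dim, heads, d_ff):
    head_dim = dim // heads
    keys = jax.random.split(rng, 6)
    shapes = {"wq": (dim, dim), "wk": (dim, head_dim), "wv": (dim, head_dim),
              "wo": (dim, dim), "w1": (dim, d_ff), "w2": (d_ff, dim)}
    return {name: jax.random.normal(k, shape) / jnp.sqrt(shape[0])
            for k, (name, shape) in zip(keys, shapes.items())}


def decoder_block(p, x, mask, heads):
    b, t, d = x.shape
    hd = d // heads
    # Sub-layer 1: multi-query self-attention + residual (pre-norm).
    h = layer_norm(x)
    q = (h @ p["wq"]).reshape(b, t, heads, hd)
    k, v = h @ p["wk"], h @ p["wv"]  # one key/value head shared by all query heads
    s = jnp.einsum("bqhd,bkd->bhqk", q, k) / jnp.sqrt(hd)
    w = jax.nn.softmax(jnp.where(mask, s, -1e9), axis=-1)
    x = x + jnp.einsum("bhqk,bkd->bqhd", w, v).reshape(b, t, d) @ p["wo"]
    # Sub-layer 2: position-wise feed-forward + residual (pre-norm).
    return x + jax.nn.relu(layer_norm(x) @ p["w1"]) @ p["w2"]


dim, heads, d_ff, depth, seq_len = 64, 8, 256, 6, 10
mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
layers = [init_block_params(k, dim, heads, d_ff)
          for k in jax.random.split(jax.random.PRNGKey(1), depth)]


def decoder(x):
    for p in layers:  # stack `depth` blocks
        x = decoder_block(p, x, mask, heads)
    return x


x = jax.random.normal(jax.random.PRNGKey(0), (2, seq_len, dim))
out = decoder(x)
print(out.shape)  # (2, 10, 64)
```

Pre-norm keeps the residual path free of normalization, which is a common choice for deep stacks; the original Transformer instead normalized after each residual addition.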
### Causal Masking
The **Causal Mask** ensures that during training or inference, tokens in the sequence can only attend to themselves or previous tokens. This prevents "future leakage" and is crucial for tasks such as language modeling and text generation.
```python
def causal_mask(seq_len):
...
```
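A minimal sketch of a causal mask, assuming a boolean convention in which `True` marks the key positions a query may attend to. Some implementations, possibly including this package's `causal_mask`, return an additive mask instead (0 where attention is allowed, a large negative number where it is blocked); the conversion between the two is shown as well.

```python
import jax.numpy as jnp


def causal_mask(seq_len):
    # Entry [i, j] is True when key position j <= query position i.
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))


mask = causal_mask(4)
print(mask.astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]

# Equivalent additive form, added to attention scores before the softmax.
additive = jnp.where(mask, 0.0, -1e9)
```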
## Running the Transformer Decoder
To run the decoder model, execute the following script:
```bash
python run_transformer.py
```
The script passes random input through the Transformer decoder stack with multi-query attention. The output has shape `(batch_size, seq_len, dim)`, the same as the input.
## Contributing
Contributions are welcome! If you'd like to contribute, please fork the repository and submit a pull request with your improvements. You can also open an issue if you find a bug or want to request a new feature.
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
## Citation
```bibtex
@misc{JaxTransformer,
author={Kye Gomez},
title={Jax Transformer},
year={2024},
url={https://github.com/kyegomez/JaxTransformer},
}
```