# jax-flash-attn2
A flexible and efficient implementation of Flash Attention 2.0 for JAX, supporting multiple backends (GPU/TPU/CPU) and platforms (Triton/Pallas/JAX).
## Features
- 🚀 Multiple backend support: GPU, TPU, and CPU
- 🔧 Multiple platform implementations: Triton, Pallas, and JAX
- ⚡ Efficient caching of attention instances
- 🔄 Support for Grouped Query Attention (GQA) and head dimensions up to 256
- 📊 JAX sharding-friendly implementation
- 🎯 Automatic platform selection based on backend
- 🧩 Compatible with existing JAX mesh patterns
## Installation
```bash
pip install jax-flash-attn2
```
## Quick Start
```python
from jax_flash_attn2 import get_cached_flash_attention
# Get a cached attention instance
attention = get_cached_flash_attention(
    backend="gpu",                   # 'gpu', 'tpu', or 'cpu'
    platform="triton",               # 'triton', 'pallas', or 'jax'
    blocksize_q=64,                  # query block size
    blocksize_k=128,                 # key block size
    softmax_scale=headdim ** -0.5,   # optional scaling factor
)
# Use with your tensors
outputs = attention(
    query=query_states,
    key=key_states,
    value=value_states,
    bias=attention_bias,  # optional
)
```
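For a concrete end-to-end run, here is a minimal sketch with random tensors. The `(batch, seq_len, num_heads, head_dim)` layout, the GQA head counts, and the `cpu`/`jax` combination (chosen so it runs anywhere) are assumptions for illustration; on NVIDIA hardware you would pass `backend="gpu"` and `platform="triton"` as above.

```python
import jax
import jax.numpy as jnp

from jax_flash_attn2 import get_cached_flash_attention

# Illustrative GQA shapes: 8 query heads sharing 2 key/value heads, head_dim 128.
batch, seq_len, num_q_heads, num_kv_heads, headdim = 2, 512, 8, 2, 128

k_q, k_k, k_v = jax.random.split(jax.random.PRNGKey(0), 3)
query_states = jax.random.normal(k_q, (batch, seq_len, num_q_heads, headdim), dtype=jnp.float16)
key_states = jax.random.normal(k_k, (batch, seq_len, num_kv_heads, headdim), dtype=jnp.float16)
value_states = jax.random.normal(k_v, (batch, seq_len, num_kv_heads, headdim), dtype=jnp.float16)

attention = get_cached_flash_attention(
    backend="cpu",       # use "gpu" + "triton" on NVIDIA GPUs
    platform="jax",
    blocksize_q=64,
    blocksize_k=128,
    softmax_scale=headdim ** -0.5,
)
outputs = attention(query=query_states, key=key_states, value=value_states)
print(outputs.shape)  # expected: (batch, seq_len, num_q_heads, headdim)
```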
## Usage with JAX Sharding
```python
with mesh:
    attention_outputs = get_cached_flash_attention(
        backend="gpu",
        platform="triton",
        blocksize_q=128,
        blocksize_k=128,
        softmax_scale=None,
    )(
        query=with_sharding_constraint(query_states, qps).astype(dtype),
        key=with_sharding_constraint(key_states, kps).astype(dtype),
        value=with_sharding_constraint(value_states, vps).astype(dtype),
        bias=with_sharding_constraint(bias, bps).astype(dtype),
    )
```
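The names used above (`mesh`, `qps`, `kps`, `vps`, `bps`, `dtype`, `with_sharding_constraint`) are left undefined in the snippet; one possible setup is sketched below. The axis names and partition specs are assumptions for illustration and should be adapted to your model's sharding layout.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

dtype = jnp.float16

# A simple 2D mesh: data parallelism over all devices, no tensor parallelism.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("dp", "tp"))

# Shard the batch axis over "dp" and the head axis over "tp";
# the bias is assumed to be (batch, num_heads, q_len, kv_len).
qps = kps = vps = PartitionSpec("dp", None, "tp", None)
bps = PartitionSpec("dp", "tp", None, None)

def with_sharding_constraint(x, spec):
    # Thin wrapper over jax.lax.with_sharding_constraint bound to the mesh above.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))
```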
## Supported Configurations
### Backends
- `gpu`: CUDA-capable GPUs
- `tpu`: Google Cloud TPUs
- `cpu`: CPU fallback
### Platforms
- `triton`: Optimized for NVIDIA GPUs
- `pallas`: Optimized for TPUs and supported on GPUs
- `jax`: Universal fallback, supports all backends
### Valid Backend-Platform Combinations
| Backend | Supported Platforms |
| ------- | ------------------- |
| GPU | Triton, Pallas, JAX |
| TPU | Pallas, JAX |
| CPU | JAX |
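If you prefer to derive the combination at runtime, a hedged sketch of one way to do it is below; `select_backend_and_platform` is a hypothetical helper written for this README, not part of the `jax_flash_attn2` API.

```python
import jax

# Preferred platform per backend, mirroring the table above.
_PREFERRED_PLATFORM = {"gpu": "triton", "tpu": "pallas", "cpu": "jax"}

def select_backend_and_platform(backend: str | None = None) -> tuple[str, str]:
    # jax.default_backend() returns 'gpu', 'tpu', or 'cpu' for the active runtime.
    backend = backend or jax.default_backend()
    return backend, _PREFERRED_PLATFORM.get(backend, "jax")
```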
## Advanced Configuration
### Custom Block Sizes
```python
attention = get_cached_flash_attention(
    backend="gpu",
    platform="triton",
    blocksize_q=128,    # customize query block size
    blocksize_k=128,    # customize key block size
    softmax_scale=1.0,  # custom softmax scaling
)
```
### Environment Variables
- `FORCE_MHA`: Set to `"true"`, `"1"`, or `"on"` to force the standard MHA implementation even for GQA inputs, as in the snippet below
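For example, the flag can be set from Python before the attention instance is created (that the flag is read at construction time is an assumption):

```python
import os

# Force the plain MHA path even when key/value heads are fewer than query heads (GQA).
os.environ["FORCE_MHA"] = "true"
```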
## Performance Tips
1. **Block Sizes**: Default block sizes (128) work well for most cases, but you might want to tune them for your specific hardware and model architecture.
2. **Platform Selection**:
- For NVIDIA GPUs: prefer `triton`
- For TPUs: prefer `pallas`
- For CPU or fallback: use `jax`
3. **Caching**: `get_cached_flash_attention` caches attention instances keyed by its parameters, so repeated calls with the same arguments reuse the same instance; there is no need to manage caching manually (see the sketch below).
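A quick way to see the caching behaviour (assuming the cache is keyed on the call parameters, two identical calls should return the same object):

```python
from jax_flash_attn2 import get_cached_flash_attention

a = get_cached_flash_attention(backend="cpu", platform="jax", blocksize_q=128, blocksize_k=128)
b = get_cached_flash_attention(backend="cpu", platform="jax", blocksize_q=128, blocksize_k=128)
assert a is b  # same cached instance for identical parameters
```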
## Requirements
- JAX
- einops
- chex
- `jax.experimental.pallas` (for TPU support)
- `triton` (for the GPU-optimized Triton kernels)
## Limitations
- Triton platform is only available on NVIDIA GPUs.
- Some platform-backend combinations are not supported (see table above).
- Custom attention masks are not yet supported; pass an additive bias instead (see the sketch below).
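As a workaround for the missing mask support, a boolean mask can be folded into an additive bias. The sketch below assumes the bias is broadcast-added to the attention logits with shape `(batch, num_heads, q_len, kv_len)`:

```python
import jax.numpy as jnp

def mask_to_bias(mask, dtype=jnp.float16):
    # True = attend, False = masked out; masked positions get a large negative value.
    return jnp.where(mask, 0.0, jnp.finfo(dtype).min).astype(dtype)

# Example: a causal mask for q_len == kv_len == 16, broadcast over batch and heads.
causal = jnp.tril(jnp.ones((16, 16), dtype=bool))
attention_bias = mask_to_bias(causal)[None, None, :, :]
```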
## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
## Citation
If you use this implementation in your research, please cite:
```bibtex
@software{jax_flash_attn2,
title = {JAX Flash Attention 2.0},
year = {2024},
url = {https://github.com/erfanzar/jax-flash-attn2}
}
```
## Acknowledgments and References
1. The MHA implementation is based on:
   - The [FlashAttention](https://arxiv.org/abs/2205.14135) and [FlashAttention-2](https://arxiv.org/abs/2307.08691) papers
   - JAX ecosystem tools and libraries
   - The Triton and Pallas optimization frameworks
2. The custom Triton kernels use [`JAX-Triton`](https://github.com/jax-ml/jax-triton/).
3. All kernels are taken from [`EasyDeL`](https://github.com/erfanzar/Easydel).