jax-flash-attn2

Name: jax-flash-attn2
Version: 0.0.1
Home page: https://github.com/erfanzar/jax-flash-attn2
Summary: Flash Attention implementation with multiple backend support and sharding. This module provides a flexible implementation of Flash Attention with support for different backends (GPU, TPU, CPU) and platforms (Triton, Pallas, JAX).
Upload time: 2024-10-23 22:37:19
Author: Erfan Zare Chavoshi
Requires Python: >=3.10
License: Apache-2.0
Keywords: jax, deep learning, machine learning, xla
# jax-flash-attn2

A flexible and efficient implementation of Flash Attention 2.0 for JAX, supporting multiple backends (GPU/TPU/CPU) and platforms (Triton/Pallas/JAX).

## Features

- 🚀 Multiple backend support: GPU, TPU, and CPU
- 🔧 Multiple platform implementations: Triton, Pallas, and JAX
- ⚡ Efficient caching of attention instances
- 🔄 Support for Grouped Query Attention (GQA) and head dimensions up to 256
- 📊 JAX sharding-friendly implementation
- 🎯 Automatic platform selection based on backend
- 🧩 Compatible with existing JAX mesh patterns


## Installation

```bash
pip install jax-flash-attn2
```

## Quick Start

```python
from jax_flash_attn2 import get_cached_flash_attention

# Get a cached attention instance
attention = get_cached_flash_attention(
	backend="gpu",  # 'gpu', 'tpu', or 'cpu'
	platform="triton",  # 'triton', 'pallas', or 'jax'
	blocksize_q=64,  # query block size
	blocksize_k=128,  # key block size
	softmax_scale=headdim ** -0.5,  # optional scaling factor; headdim is your model's per-head dimension
)

# Use with your tensors
outputs = attention(
	query=query_states,
	key=key_states,
	value=value_states,
	bias=attention_bias, # Optional
)
```
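
To try the snippet end to end, you can feed it random tensors. The `(batch, seq_len, num_heads, head_dim)` layout and the dtype below are assumptions for illustration; check the repository for the exact shape conventions each platform expects.

```python
import jax
import jax.numpy as jnp

# Assumed layout: (batch, seq_len, num_heads, head_dim); `headdim` above would be 64 here.
batch, seq_len, num_heads, headdim = 2, 1024, 8, 64

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query_states = jax.random.normal(kq, (batch, seq_len, num_heads, headdim), dtype=jnp.float16)
key_states = jax.random.normal(kk, (batch, seq_len, num_heads, headdim), dtype=jnp.float16)
value_states = jax.random.normal(kv, (batch, seq_len, num_heads, headdim), dtype=jnp.float16)

outputs = attention(query=query_states, key=key_states, value=value_states)
print(outputs.shape)  # expected to match query_states.shape
```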

## Usage with JAX Sharding

```python
with mesh:
	attention_outputs = get_cached_flash_attention(
		backend="gpu",
		platform="triton",
		blocksize_q=128,
		blocksize_k=128,
		softmax_scale=None,
	)(
		query=with_sharding_constraint(query_states, qps).astype(dtype),
		key=with_sharding_constraint(key_states, kps).astype(dtype),
		value=with_sharding_constraint(value_states, vps).astype(dtype),
		bias=with_sharding_constraint(bias, bps).astype(dtype),
	)
```
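
The snippet above assumes `mesh`, the partition specs (`qps`, `kps`, `vps`, `bps`), `dtype`, and `with_sharding_constraint` are already defined. A minimal sketch of that setup, with assumed axis names and an assumed `(batch, seq_len, heads, head_dim)` layout, might look like the following; adapt the mesh shape and specs to your topology.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.lax import with_sharding_constraint
from jax.sharding import Mesh, PartitionSpec

# Hypothetical 2D mesh: a data-parallel axis 'dp' and a tensor-parallel axis 'tp'.
mesh = Mesh(np.array(jax.devices()).reshape(1, -1), axis_names=("dp", "tp"))

# Assumed (batch, seq_len, heads, head_dim) layout: shard batch over 'dp', heads over 'tp'.
qps = kps = vps = PartitionSpec("dp", None, "tp", None)
bps = PartitionSpec("dp", None, None, None)  # bias spec; depends on your bias shape
dtype = jnp.float16
```

Note that `with_sharding_constraint` with bare `PartitionSpec`s is typically used inside a `jax.jit`-compiled function under the mesh context.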

## Supported Configurations

### Backends
- `gpu`: CUDA-capable GPUs
- `tpu`: Google Cloud TPUs
- `cpu`: CPU fallback

### Platforms
- `triton`: Optimized for NVIDIA GPUs
- `pallas`: Optimized for TPUs and supported on GPUs
- `jax`: Universal fallback, supports all backends

### Valid Backend-Platform Combinations

| Backend | Supported Platforms |
| ------- | ------------------- |
| GPU     | Triton, Pallas, JAX |
| TPU     | Pallas, JAX         |
| CPU     | JAX                 |
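
If you want to pick a platform programmatically, a small helper along these lines mirrors the table above. This is a hypothetical convenience function, not part of the jax-flash-attn2 API; it only relies on `jax.default_backend()` reporting the active JAX backend.

```python
import jax

# Hypothetical helper (not part of jax-flash-attn2): map the active JAX
# backend to the preferred platform from the table above.
def preferred_platform(backend: str | None = None) -> str:
    backend = (backend or jax.default_backend()).lower()  # 'gpu', 'tpu', or 'cpu'
    return {"gpu": "triton", "tpu": "pallas"}.get(backend, "jax")
```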

## Advanced Configuration

### Custom Block Sizes

```python
attention = get_cached_flash_attention(
    backend="gpu",
    platform="triton",
    blocksize_q=128,    # Customize query block size
    blocksize_k=128,    # Customize key block size
    softmax_scale=1.0,  # Custom softmax scaling
)
```

### Environment Variables

- `FORCE_MHA`: set to `"true"`, `"1"`, or `"on"` to force the MHA implementation even for GQA cases
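
For example, to force the MHA path from the shell (the variable is presumably read when the attention instance is created, so set it before launching your script):

```bash
export FORCE_MHA=1
python your_script.py
```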

## Performance Tips

1. **Block Sizes**: Default block sizes (128) work well for most cases, but you might want to tune them for your specific hardware and model architecture.

2. **Platform Selection**:
   - For NVIDIA GPUs: prefer `triton`
   - For TPUs: prefer `pallas`
   - For CPU or fallback: use `jax`

3. **Caching**: The `get_cached_flash_attention` function automatically caches instances based on parameters. No need to manage caching manually.
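
   As a concrete (assumed) illustration: calls with identical parameters should return the same cached instance, so constructing it inside a training step is cheap.

```python
a = get_cached_flash_attention(backend="gpu", platform="triton", blocksize_q=128, blocksize_k=128, softmax_scale=None)
b = get_cached_flash_attention(backend="gpu", platform="triton", blocksize_q=128, blocksize_k=128, softmax_scale=None)
assert a is b  # same parameters -> same cached instance (assumed caching behavior)
```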

## Requirements

- JAX
- einops
- chex
- jax.experimental.pallas (for TPU support)
- triton (for GPU optimized implementation)

## Limitations

- Triton platform is only available on NVIDIA GPUs.
- Some platform-backend combinations are not supported (see table above).
- Custom attention masks are not yet supported (use bias instead).
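
Because masks are not accepted directly, a common workaround is to fold a boolean mask into an additive bias: zeros where attention is allowed and a large negative value where it is blocked. The `(batch, heads, q_len, k_len)` bias layout below is an assumption for illustration.

```python
import jax.numpy as jnp

def mask_to_bias(mask, dtype=jnp.float16):
    """Convert a boolean mask (True = attend, False = block) into an additive bias."""
    return jnp.where(mask, 0.0, jnp.finfo(dtype).min).astype(dtype)

# Example: causal mask for seq_len tokens, broadcast over batch and heads
# (assumed bias layout: (batch, heads, q_len, k_len)).
seq_len = 1024
causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
attention_bias = mask_to_bias(causal)[None, None, :, :]
```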

## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
 
## Citation

If you use this implementation in your research, please cite:

```bibtex
@software{jax_flash_attn2,
    title = {JAX Flash Attention 2.0},
    year = {2024},
    url = {https://github.com/erfanzar/jax-flash-attn2}
}
```
## Acknowledgments and References

1. The MHA implementation is based on:
- [FlashAttention paper](https://arxiv.org/abs/2205.14135)
- JAX ecosystem tools and libraries
- Triton and Pallas optimization frameworks

2. The custom Triton kernels use [`JAX-Triton`](https://github.com/jax-ml/jax-triton/).

3. All kernels are taken from [`EasyDeL`](https://github.com/erfanzar/Easydel).
