flash-attn-jax


Name: flash-attn-jax
Version: 0.3.0
Summary: Flash Attention port for JAX
Homepage: https://github.com/nshepperd/flash_attn_jax
Author email: Tri Dao <tri@tridao.me>, Emily Shepperd <em@zlkj.in>
License: BSD-3-Clause
Requires Python: >=3.11
Upload time: 2025-07-28 09:42:37
Requirements: none recorded
# FlashAttention JAX

This repository provides a JAX binding for <https://github.com/Dao-AILab/flash-attention>. To avoid depending on PyTorch (Torch and JAX installations often conflict), it is maintained as a fork of the official repo rather than as a wrapper around it.

Please see [Tri Dao's repo](https://github.com/Dao-AILab/flash-attention) for more information about Flash Attention, including how to cite the authors if you use it in your work.

FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
Please cite and credit FlashAttention if you use it (see Tri Dao's repo above for citation details).

## Installation

Requirements:
- CUDA 12.8 or above.
- Linux. Same story as with the PyTorch repo; I haven't tested compilation of the JAX bindings on Windows.
- JAX >= `0.5.*`. The custom call API changed in this version.

To install, run `pip install flash-attn-jax` to get the latest release from PyPI. This gives you the CUDA 12.8 build. CUDA 11 is no longer supported (since JAX dropped support for it).
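
For example, in a fresh environment you might install a CUDA-enabled JAX alongside it (the `jax[cuda12]` extra is the standard upstream JAX install path; the exact steps and pins here are illustrative, not prescribed by this repo):

```sh
# Illustrative fresh-environment install; adjust versions to your setup.
python -m venv .venv && source .venv/bin/activate
pip install -U "jax[cuda12]"   # CUDA-enabled JAX from upstream
pip install flash-attn-jax     # prebuilt CUDA 12.8 wheel from PyPI
```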

### Installing from source

Flash Attention takes a long time to compile unless you have a powerful machine. If you do want to compile from source, I use `cibuildwheel` to build the releases, and you can do the same. Something like this (for Python 3.12):

```sh
git clone https://github.com/nshepperd/flash-attn-jax
cd flash-attn-jax
cibuildwheel --only cp312-manylinux_x86_64  # on some systems this needs superuser privileges, because it runs docker
```

This will create a wheel in the `wheelhouse` directory, which you can install with `pip install wheelhouse/flash_attn_jax-*.whl`. Or you can build it without docker using `uv build --wheel`; in that case you need CUDA installed locally.
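
A minimal sketch of the docker-free path (assuming `uv` is installed and a local CUDA toolkit is available; the wheel filename varies by version and Python ABI):

```sh
git clone https://github.com/nshepperd/flash-attn-jax
cd flash-attn-jax
uv build --wheel                      # compiles against the locally installed CUDA toolkit
pip install dist/flash_attn_jax-*.whl
```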

## Usage

Interface: `src/flash_attn_jax/flash.py`

```py
from flash_attn_jax import flash_mha

# flash_mha : [n, l, h, d] x [n, lk, hk, d] x [n, lk, hk, d] -> [n, l, h, d]
flash_mha(q, k, v, softmax_scale=None, is_causal=False, window_size=(-1, -1))
```

This supports multi-query and grouped-query attention (when `hk != h`). The `softmax_scale` is the scaling applied to the attention scores before the softmax, defaulting to `1/sqrt(d)`. Set `window_size` to positive `(left, right)` values for sliding-window attention.
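
A minimal usage sketch (the shapes and dtype here are illustrative, not requirements of the API):

```py
import jax
import jax.numpy as jnp
from flash_attn_jax import flash_mha

n, l, d = 2, 1024, 64   # batch, sequence length, head dimension
h, hk = 8, 2            # 8 query heads sharing 2 key/value heads (grouped-query attention)

key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (n, l, h, d), dtype=jnp.float16)
k = jax.random.normal(key, (n, l, hk, d), dtype=jnp.float16)
v = jax.random.normal(key, (n, l, hk, d), dtype=jnp.float16)

# Causal grouped-query attention.
out = flash_mha(q, k, v, is_causal=True)
print(out.shape)  # (2, 1024, 8, 64)

# Sliding-window attention over the previous 256 tokens (window is (left, right)).
out_local = flash_mha(q, k, v, is_causal=True, window_size=(256, 0))
```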

### Now Supports Ring Attention

Use `jax.Array` and shard your tensors along the length dimension, and `flash_mha` will automatically use the ring attention algorithm (for both the forward and backward pass).

```py
import os
os.environ["XLA_FLAGS"] = '--xla_gpu_enable_latency_hiding_scheduler=true'  # set before importing jax

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()  # one shard of the sequence per device
with Mesh(devices, axis_names=('len',)) as mesh:
    sharding = NamedSharding(mesh, P(None, 'len'))  # [n, l, ...]: shard the length axis
    tokens = jax.device_put(tokens, sharding)
    # invoke your jax.jit'd transformer.forward
```

The latency hiding (overlapping ring communication with compute) seems to be reliable now that some bugs have been fixed, as long as you enable the latency hiding scheduler as above.
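
A minimal end-to-end sketch (the shapes, mesh setup, and the `attend` wrapper are illustrative assumptions, not part of the package API):

```py
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flash_attn_jax import flash_mha

devices = jax.devices()
mesh = Mesh(devices, axis_names=('len',))
shard_len = NamedSharding(mesh, P(None, 'len', None, None))  # [n, l, h, d], split along l

n, l, h, d = 1, 8192, 8, 64
q = jax.device_put(jnp.zeros((n, l, h, d), jnp.float16), shard_len)
k = jax.device_put(jnp.zeros((n, l, h, d), jnp.float16), shard_len)
v = jax.device_put(jnp.zeros((n, l, h, d), jnp.float16), shard_len)

@jax.jit
def attend(q, k, v):
    return flash_mha(q, k, v, is_causal=True)

out = attend(q, k, v)  # ring attention across the 'len' axis; output stays sharded the same way
```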

### GPU support

FlashAttention-2 currently supports:
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
   GPUs (T4, RTX 2080) is coming soon; please use FlashAttention 1.x for Turing GPUs for now.
2. Datatypes fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs); see the sketch after this list.
3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
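
Since only fp16 and bf16 inputs are supported, activations kept in fp32 need to be cast before the call. A minimal sketch (the `mha_bf16` wrapper is hypothetical, and bf16 assumes an Ampere-or-newer GPU):

```py
import jax.numpy as jnp
from flash_attn_jax import flash_mha

def mha_bf16(q, k, v, **kwargs):
    # Cast fp32 activations down to bf16 for the kernel, then restore the input dtype.
    orig_dtype = q.dtype
    q, k, v = (x.astype(jnp.bfloat16) for x in (q, k, v))
    return flash_mha(q, k, v, **kwargs).astype(orig_dtype)
```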

            
