# FlashAttention2 with Custom Masks 🎭
**Note: This is an unofficial implementation of FlashAttention2.**
For efficiency reasons, the standard implementations of FlashAttention do not currently support **arbitrary custom masks**.
Specific masks, such as the causal mask used in language modeling, are instead implemented with branch logic to save memory. This repository is a modified version of the tutorial Triton implementation of FlashAttention2 that lets the user
define a (batch of) custom masks. Both the forward and backward passes are modified to handle custom masking, so you can define a different mask per head and per batch element.
Original Triton code: [https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html](https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html)
See the original thread: [https://github.com/Dao-AILab/flash-attention/issues/352](https://github.com/Dao-AILab/flash-attention/issues/352)
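Conceptually, a custom mask just forces the disallowed attention logits to `-inf` before the softmax, so those positions receive zero weight. A minimal single-head NumPy sketch of this idea (illustrative only, not the Triton kernel):

```python
import numpy as np

def masked_attention(q, k, v, mask, sm_scale):
    """Single-head attention with a binary mask (1 = attend, 0 = block)."""
    scores = np.where(mask == 1, (q @ k.T) * sm_scale, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # masked positions get weight 0
    return weights @ v

rng = np.random.default_rng(0)
L, D = 8, 4
q, k, v = (rng.standard_normal((L, D)) for _ in range(3))
causal = np.tril(np.ones((L, L), dtype=int))  # example: causal mask
out = masked_attention(q, k, v, causal, 1 / D**0.5)
```

With the causal mask, row 0 can only attend to position 0, so its output is exactly `v[0]`.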
## Example Setup
The relevant libraries needed to use the custom-mask FlashAttention2 kernel are below:
```
pip install "triton>=3.0.0"
pip install torch
```
#### For Viewing Benchmarking Results
The additional libraries for evaluating performance are listed below. These are primarily for `test_benchmark.py`, which also verifies the correctness of the implementation.
```
pip install pytest
pip install matplotlib
pip install pandas
```
To compare with the official FlashAttention and `xformers.ops.memory_efficient_attention` implementations, make sure to install both libraries separately (follow the instructions on these repositories).
```
pip install flash-attn --no-build-isolation
pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121
```
## Testing Correctness
There are two `pytest` functions in `test_benchmark.py`. The first tests whether a reference implementation of multi-head attention with a causal mask matches the Triton version, in both the forward-pass outputs and the backward-pass gradients. The second runs the same check with **random** masks. You can extend these tests for more rigorous correctness checking and run them with `pytest`.
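As a rough NumPy analogue of what the causal-mask test checks (helper names here are hypothetical, not the repository's): applying an explicit lower-triangular 0/1 mask should produce exactly the same output as branch-logic causal masking.

```python
import numpy as np

def softmax_rows(s):
    s = s - s.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(s)
    return w / w.sum(axis=-1, keepdims=True)

def attn_explicit_mask(q, k, v, mask, sm_scale):
    # Custom-mask path: 1 = attend, 0 = block
    s = np.where(mask == 1, (q @ k.T) * sm_scale, -np.inf)
    return softmax_rows(s) @ v

def attn_causal_branch(q, k, v, sm_scale):
    # Branch-logic path: mask future positions without an explicit mask input
    L = q.shape[0]
    s = (q @ k.T) * sm_scale
    s[np.triu(np.ones((L, L), dtype=bool), k=1)] = -np.inf
    return softmax_rows(s) @ v

rng = np.random.default_rng(0)
L, D = 16, 8
q, k, v = (rng.standard_normal((L, D)) for _ in range(3))
causal = np.tril(np.ones((L, L), dtype=int))

a = attn_explicit_mask(q, k, v, causal, 1 / D**0.5)
b = attn_causal_branch(q, k, v, 1 / D**0.5)
```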
## Simple Example
You can insert this module into your standard attention pipeline.
```python
import torch

from fa2_custom_mask import flash_attention_custom_mask

B, H, L, D = 4, 16, 4096, 64
sm_scale = 1 / (D ** 0.5)

# Queries, keys, and values: (batch, heads, seq_len, head_dim)
fp32_q = torch.randn(B, H, L, D).float().cuda()
fp32_k = torch.randn(B, H, L, D).float().cuda()
fp32_v = torch.randn(B, H, L, D).float().cuda()

# Binary mask (1 = attend, 0 = block), shared across heads here
mask = torch.randint(0, 2, (B, 1, L, L)).int().cuda()
mask = torch.broadcast_to(mask, (B, H, L, L))

out = flash_attention_custom_mask(fp32_q, fp32_k, fp32_v, mask=mask, sm_scale=sm_scale)
...
out.backward(loss)
```
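In practice you will often want a structured mask rather than a random one. For example, a causal mask in the same `(B, H, L, L)` integer format used above can be built with `torch.tril` (a sketch; `expand` broadcasts without copying):

```python
import torch

B, H, L = 4, 16, 128
# Lower-triangular causal mask: position i may attend to positions j <= i.
# Same 0/1 integer convention as the random mask above.
causal = torch.tril(torch.ones(L, L, dtype=torch.int32))
# Broadcast to (B, H, L, L) without copying the underlying data.
causal = causal.view(1, 1, L, L).expand(B, H, L, L)
```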
## Benchmarking
We run a simple benchmark against the base Triton implementation. In our custom-mask version, we pass the canonical causal mask in as an input (hence storing it in global device memory). `test_benchmark.py` runs
with batch size 4, 16 heads, hidden dimension 64, and sequence length `N_CTX` ranging from 256 to 16384 in powers of 2. You can replicate the experiments by running
```
pytest
python test_benchmark.py
```
#### Causal Masks and No Masks Comparisons
We compare against the original experiments and original implementation, as well as the official FlashAttention and xformers implementations. (Note: a versioning issue meant a different xformers implementation was used here; the version is corrected in the later benchmarking experiments.)
![causal and no masking with flash attn](./data/results-causal-fa.png)
#### Causal Masks and No Masks Comparisons (with Correct xformers Version)
We compare against the original experiments and original implementation, as well as the xformers implementation. Notably, the original implementation does well for causal masking because of pipelining tricks and because it never has to store the mask.
![causal and no masking](./data/results-causal.png)
#### Custom Masking Comparison
We compare directly against [xformers memory efficient attention](https://facebookresearch.github.io/xformers/components/ops.html), which supports custom masking. We generate random masks (fixed across the head dimension).
![custom masking](./data/results-random.png)
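A note for replicating this comparison: `xformers.ops.memory_efficient_attention` takes an *additive* `attn_bias` (0 where attention is allowed, `-inf` where blocked) rather than a 0/1 mask, so binary masks have to be converted first. A sketch of that conversion (NumPy for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
binary_mask = rng.integers(0, 2, size=(4, 4))             # 1 = attend, 0 = block
additive_bias = np.where(binary_mask == 1, 0.0, -np.inf)  # 0 keeps, -inf blocks
```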
## Notes and Bugs
1. This implementation only works on Ampere devices and up. I originally tried running it on a V100 (Volta) and it failed.
2. You need `triton>=3.0.0`, or the kernel will complain about permutation indices on the value-vector pointer. The `torch` and `flash-attn` libraries may force `triton==2.x.x` on install, but you can simply re-install `triton>=3.0.0` afterwards and it should work. I may fix this manually in the future.
 * This is oddly specific, but I'm not able to have `flash-attn` and `xformers` installed at the same time; I had to run them separately to generate the plots.
3. TODO: Add benchmarking for peak memory consumption and other efficiency metrics.
If time permits, I'm interested in generalizing this implementation / adapting the CUDA implementation for FA3 (if that's necessary, of course). I'll probably also run some more realistic workloads and see what happens.