flash-attention-softmax-n


Name: flash-attention-softmax-n
Version: 0.3.2
Home page: https://github.com/softmax1/Flash-Attention-Softmax-N
Summary: CUDA and Triton implementations of Flash Attention with SoftmaxN.
Upload time: 2023-11-21 14:15:29
Author: Christopher W. Murphy
Requires Python: >=3.9
License: GPLv3
Keywords: artificial intelligence, attention mechanism, transformers
# Flash-Attention-Softmax-N

[Flash attention](https://arxiv.org/abs/2205.14135) with softmaxN.
[Attention is Off By One](https://www.evanmiller.org/attention-is-off-by-one.html) hypothesized that using softmax1 in the attention mechanism will reduce the number of outliers in the activations and weights of a transformer model.

🎯 **Efficient, Numerically-Stable Implementation of SoftmaxN**: No more worrying about the non-trivial implementation of softmaxN.
$$\text{softmax}_n(x_i) = \frac{\exp(x_i)}{n + \sum_j \exp(x_j)}$$

🚀 **Multiple Attention Implementations, your choice**: Whatever you're aiming for, we've got you covered with three attention implementations.
In the spirit of the flash attention paper, further gains can be made by considering the whole attention function instead of just the softmaxN subfunction.
- `flash_attention_n`: recommended for integer values of _n_, uses CUDA on the backend if a GPU is available 
- `flash_attention_n_triton`: recommended for non-integer values of _n_ when a GPU is available, uses Triton
- `slow_attention_n`: flexible, torch-based implementation

🧠 **Run statistical analyses**: Compute summary statistics for both the weights and activations of your model.
The activation stats are computed online as the model is training.

🔥 **Perform "surgery" on existing models**: Take a pretrained model with softmax_0 in its attention mechanism and "operate" on it to replace softmax_0 with softmax_n.

## Install
Simple installation
```bash
$ pip install flash-attention-softmax-n
```
Optionally install the Triton implementation
```bash
$ pip install flash-attention-softmax-n[triton]
$ pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
```
Optionally install the surgery subpackage for converting pretrained models to softmax_n
```bash
$ pip install flash-attention-softmax-n[surgery]
```

## Usage


|              Feature / Function              | `flash_attention_n` |    `flash_attention_n_triton`    | `slow_attention_n` |
|:--------------------------------------------:|:-------------------:|:--------------------------------:|:------------------:|
|               CPU-compatible?                |         Yes         |                No                |        Yes         |
|         Real- or integer-valued $n$          |       Integer       |               Real               |        Real        |
|     Datatypes natively supported on GPU      |  fp32, fp16, bf16   |        fp16 (*see below)         |  fp32, fp16, bf16  |
|     Datatypes natively supported on CPU      |     fp32, bf16      |               n/a                |     fp32, bf16     |
|                   Dropout?                   |         Yes         |                No                |        Yes         |
|                 Causal Mask?                 |         Yes         | Only tested for $n \leq 10^{-3}$ |        Yes         |
|            Attention Bias (ALiBi)            |         Yes         |                No                |         No         |
|                Attention Mask                |         Yes         |                No                |        Yes         |
|          supports `query.ndim < 4`           |         No          |                No                |        Yes         |
| supports `key.ndim < 4` and `value.ndim < 4` |         Yes         |                No                |        Yes         |
| requires `key.shape[-1] == value.shape[-1]`  |         No          |               Yes                |         No         |
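The table can be read as a decision procedure. The helper below is a hypothetical sketch of that reading: `pick_attention_function` and its parameters are ours, not part of the package, and it only returns the *name* of the function to call, so it runs without torch or a GPU.

```python
def pick_attention_function(
    n: float,
    on_gpu: bool,
    is_causal: bool = False,
    need_dropout: bool = False,
    need_attn_mask: bool = False,
    need_attn_bias: bool = False,
    query_ndim: int = 4,
    key_value_ndim: int = 4,
    same_head_dim: bool = True,
) -> str:
    """Hypothetical helper: map the comparison table to a function name."""
    integer_n = float(n).is_integer()

    # Attention bias (ALiBi) is only listed for flash_attention_n,
    # which also needs an integer n and a 4D query.
    if need_attn_bias:
        if integer_n and query_ndim == 4:
            return "flash_attention_n"
        raise ValueError("attention bias requires integer n and query.ndim == 4")

    # Integer n: flash_attention_n runs on CPU or GPU but needs a 4D query.
    if integer_n and query_ndim == 4:
        return "flash_attention_n"

    # Triton: GPU only, no dropout or mask, 4D inputs, equal key/value head dims,
    # and causal masking has only been tested for n <= 1e-3.
    triton_ok = (
        on_gpu
        and not need_dropout
        and not need_attn_mask
        and query_ndim == 4
        and key_value_ndim == 4
        and same_head_dim
        and (not is_causal or n <= 1e-3)
    )
    if triton_ok:
        return "flash_attention_n_triton"

    # Everything else falls back to the flexible torch implementation.
    return "slow_attention_n"
```

For example, `pick_attention_function(0.5, on_gpu=True, need_dropout=True)` falls back to `"slow_attention_n"` because the Triton kernel has no dropout.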

### CUDA
The recommended function for integer values of _n_, with or without a GPU.
You'll probably need an A100 to reap the full benefit, though.
This implementation was inspired by [x-transformers](https://github.com/lucidrains/x-transformers/tree/main).
It uses `torch.nn.functional.scaled_dot_product_attention` on the backend, which requires `torch>=2.0.0`.

```python
import torch
from flash_attention_softmax_n import flash_attention_n

softmax_n_param = 1
query = torch.randn((6, 1, 1024, 64))
key = torch.randn((6, 1152, 64))
value = torch.randn((6, 1152, 32))

attn = flash_attention_n(
    query=query,
    key=key,
    value=value,
    softmax_n_param=softmax_n_param,
    scale=None,
    dropout_p=0.,
    attn_mask=None,
    attn_bias=None,
    is_causal=False
)
```
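`scaled_dot_product_attention` only computes an ordinary softmax. One way to obtain softmax_n from it for integer _n_ (and plausibly why this path is integer-only) is to append _n_ all-zero keys and values: each zero key contributes a logit of 0, i.e. an extra exp(0) = 1 in the denominator, and each zero value adds nothing to the output. For n = 1 this is similar in spirit to the `add_zero_attn` option of `torch.nn.MultiheadAttention`. We have not verified that the package does exactly this; the pure-Python sketch below (single query, hypothetical helper names) only checks that the identity holds.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def softmax(logits: Vector) -> Vector:
    # Ordinary max-shifted softmax, i.e. what a stock attention kernel computes.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_sum(weights: Vector, values: List[Vector]) -> Vector:
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]


def attention_softmax_n_direct(q: Vector, keys: List[Vector], values: List[Vector], n: int) -> Vector:
    # softmax_n applied directly to the scaled logits q.k / sqrt(d).
    scale = 1.0 / math.sqrt(len(q))
    logits = [dot(q, k) * scale for k in keys]
    denom = n + sum(math.exp(z) for z in logits)
    return weighted_sum([math.exp(z) / denom for z in logits], values)


def attention_softmax_n_zero_kv(q: Vector, keys: List[Vector], values: List[Vector], n: int) -> Vector:
    # Append n zero keys (logit 0, so exp(0) = 1 each in the denominator) and
    # n zero values (no contribution to the output), then use an ordinary softmax.
    scale = 1.0 / math.sqrt(len(q))
    padded_keys = keys + [[0.0] * len(keys[0]) for _ in range(n)]
    padded_values = values + [[0.0] * len(values[0]) for _ in range(n)]
    weights = softmax([dot(q, k) * scale for k in padded_keys])
    return weighted_sum(weights, padded_values)
```

With n = 0 both functions reduce to standard attention. A fractional number of zero keys cannot be appended, which is consistent with the table listing `flash_attention_n` as integer-only.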

### Triton
The recommended function to use when you want GPU acceleration and have a non-integer-valued _n_.
Note that the Triton implementation has a more limited set of features than the CUDA version; see the comparison table above.
*To use datatypes other than `fp16`, first convert your inputs to `fp16`, then convert the attention output back to your original datatype.
This is a generalization of OpenAI's Triton fused attention [implementation](https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py).
Requires `torch>=2.0.0` and `triton>=2.0.0`.

```python
import torch
from flash_attention_softmax_n import flash_attention_n_triton

softmax_n_param = 1.
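query = torch.randn((6, 1, 1024, 64), device='cuda', dtype=torch.float16)  # the Triton kernel is GPU-only and natively fp16
key = torch.randn((6, 1, 1152, 64), device='cuda', dtype=torch.float16)
value = torch.randn((6, 1, 1152, 64), device='cuda', dtype=torch.float16)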

attn = flash_attention_n_triton(
    query=query,
    key=key,
    value=value,
    softmax_n_param=softmax_n_param,
    scale=None,
    is_causal=False
)
```
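The fp16 round trip described above can be wrapped in a small helper. This is a hypothetical sketch (`call_in_fp16` is not part of the package); it assumes `query`, `key` and `value` are `torch.Tensor` objects and relies only on `.dtype`, `.half()` and `.to(dtype)`, so it does not import torch itself.

```python
from typing import Any, Callable


def call_in_fp16(attn_fn: Callable[..., Any], query: Any, key: Any, value: Any, **kwargs: Any) -> Any:
    """Run an fp16-only attention function on inputs of another floating dtype.

    Hypothetical helper: query/key/value are assumed to be torch.Tensor objects.
    """
    original_dtype = query.dtype
    # Cast the inputs down to fp16 for the kernel.
    out = attn_fn(query=query.half(), key=key.half(), value=value.half(), **kwargs)
    # Cast the fp16 output back to the caller's dtype (e.g. float32 or bfloat16).
    return out.to(original_dtype)
```

For example, `call_in_fp16(flash_attention_n_triton, query, key, value, softmax_n_param=1., scale=None, is_causal=False)` returns a tensor in the dtype of `query`. Values outside the fp16 range (about ±65504) overflow in the cast, which can matter for `bf16` or `fp32` inputs.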

### Slow Attention
Written in torch.
Use this version when you have a real-valued _n_, and the Triton version is unavailable or doesn't have the feature(s) you need.

```python
import torch
from flash_attention_softmax_n import slow_attention_n

softmax_n_param = 1.
query = torch.randn((6, 1024, 64))
key = torch.randn((6, 1152, 64))
value = torch.randn((6, 1152, 32))

attn = slow_attention_n(
    query=query,
    key=key,
    value=value,
    softmax_n_param=softmax_n_param,
    scale=None,
    dropout_p=0.,
    attn_mask=None,
    is_causal=False,
    softmax_dtype=None,
    train=True
)
```
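For real-valued _n_ > 0, substituting n = exp(log n) into the definition of softmaxN rewrites it as an ordinary softmax over the N inputs plus one extra logit equal to log n, keeping only the first N outputs. This is only algebra on the formula, not a description of how `slow_attention_n` is implemented; it also makes explicit that the outputs sum to less than one, the property the "Attention is Off By One" post relies on.

$$\text{softmax}_n(x_i) = \frac{\exp(x_i)}{\exp(\log n) + \sum_{j=1}^{N} \exp(x_j)} = \text{softmax}(x_1, \dots, x_N, \log n)_i, \qquad n > 0$$

$$\sum_{i=1}^{N} \text{softmax}_n(x_i) = 1 - \frac{n}{n + \sum_{j=1}^{N} \exp(x_j)} < 1$$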

We also provide a torch implementation of softmaxN that can be used as a drop-in replacement for softmax.
```python
import torch
from flash_attention_softmax_n import softmax_n

x = torch.rand((100, 100))
# drop-in replacement for:
# y = torch.nn.functional.softmax(x, dim=-1, dtype=torch.float32)
y = softmax_n(x, dim=-1, dtype=torch.float32)

y1 = softmax_n(x, n=1.)  # softmax_1
```
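Computing softmaxN naively as `exp(x) / (n + sum(exp(x)))` overflows for large logits. One standard fix, sketched below in pure Python, multiplies numerator and denominator by exp(-m) with m = max(max_j x_j, log n), so every exponent is at most zero. The name `softmax_n_reference` is ours, and we have not confirmed that the package's `softmax_n` uses this exact shift.

```python
import math
from typing import List


def softmax_n_reference(x: List[float], n: float = 0.0) -> List[float]:
    """Numerically stable softmax_n(x)_i = exp(x_i) / (n + sum_j exp(x_j)).

    Illustrative sketch only. Numerator and denominator are both scaled by
    exp(-m), with m = max(max_j x_j, log n), so no call to exp overflows.
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    m = max(x)
    if n > 0:
        m = max(m, math.log(n))  # log(n) acts as the logit of the extra n term
    shifted = [math.exp(xi - m) for xi in x]
    # n * exp(-m), written as exp(log(n) - m) so it cannot overflow either.
    extra = math.exp(math.log(n) - m) if n > 0 else 0.0
    # One of the terms equals exactly 1, so denom >= 1 and never underflows to 0.
    denom = extra + sum(shifted)
    return [s / denom for s in shifted]
```

With n = 0 this reduces to the usual max-shifted softmax.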

### Statistical Analysis
```python
from flash_attention_softmax_n.analysis import register_activation_hooks, compute_weight_statistics, save_results

model = GPT4()  # XD; placeholder for your own torch.nn.Module
activations_statistics = register_activation_hooks(model)  # activation stats are computed online during training, so register the hooks in advance

trainer.train(model)  # placeholder for your training loop

weight_statistics = compute_weight_statistics(model)  # weight stats are computed after training is finished

print(activations_statistics['...attention.output...']['kurtosis'])
print(weight_statistics['...attention.output...']['kurtosis'])

save_results({'activations': activations_statistics, 'weights': weight_statistics}, 'my-gpt4')
```
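Higher moments such as the kurtosis printed above can be tracked online, one value at a time, without storing the activations. The sketch below uses a Welford-style single-pass update extended to the third and fourth central moments. `RunningMoments` is a hypothetical name; the package may compute its statistics differently, and the README does not say whether it reports plain or excess kurtosis.

```python
import math


class RunningMoments:
    """Hypothetical online accumulator for mean, variance and excess kurtosis."""

    def __init__(self) -> None:
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the mean
        self.m3 = 0.0  # sum of cubed deviations
        self.m4 = 0.0  # sum of fourth-power deviations

    def update(self, x: float) -> None:
        n1 = self.count
        self.count += 1
        n = self.count
        delta = x - self.mean
        delta_n = delta / n
        delta_n2 = delta_n * delta_n
        term1 = delta * delta_n * n1
        self.mean += delta_n
        # Order matters: m4 uses the old m3 and m2, m3 uses the old m2.
        self.m4 += term1 * delta_n2 * (n * n - 3 * n + 3) + 6 * delta_n2 * self.m2 - 4 * delta_n * self.m3
        self.m3 += term1 * delta_n * (n - 2) - 3 * delta_n * self.m2
        self.m2 += term1

    @property
    def variance(self) -> float:
        # Population variance.
        return self.m2 / self.count if self.count else math.nan

    @property
    def excess_kurtosis(self) -> float:
        # Population excess kurtosis: n * m4 / m2**2 - 3 (0 for a Gaussian).
        if self.count == 0 or self.m2 == 0:
            return math.nan
        return self.count * self.m4 / (self.m2 * self.m2) - 3.0
```

A per-element Python loop is far too slow for real activations; in practice you would update from a forward hook with a vectorized, batch-merging version of the same recurrences.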

### Surgery
"Operate" on pretrained models to generalize them to softmax_n.
Based on MosaicML's [composer](https://github.com/mosaicml/composer).

Functional API: add one line of code to your script.
```python
import transformers

from flash_attention_softmax_n.surgery import apply_attention_softmax_n


model = transformers.AutoModel.from_pretrained('bert-base-uncased')
apply_attention_softmax_n(model=model, softmax_n_param=1.)
...
```

Object-oriented API for use with the MosaicML composer trainer.
```python
import composer
import transformers

from flash_attention_softmax_n.surgery import AttentionSoftmaxN


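model = transformers.AutoModel.from_pretrained('bert-base-uncased')
# composer's Trainer expects a ComposerModel, so in practice wrap the Hugging Face model first (e.g. with composer.models.HuggingFaceModel)
trainer = composer.trainer.Trainer(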
    model=model,
    algorithms=[AttentionSoftmaxN(softmax_n_param=1.)]
)
...
```

Add your model to the registry!
Currently, only BERT, RoBERTa, and XLNet (without flash attention) are registered by default.
As an example, use `policy_registry` to replace slow_attention_0 in `MyModel` with flash_attention_n.
After registration, pass the model to `apply_attention_softmax_n`.
```python
import types

import torch

from flash_attention_softmax_n import slow_attention_n, flash_attention_n
from flash_attention_softmax_n.surgery import apply_attention_softmax_n
from flash_attention_softmax_n.surgery.surgery_functions import policy_registry


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = SlowAttention()

    def forward(self, q, k, v):
        return self.attn(q, k, v)  # SlowAttention.forward only takes q, k, v


class SlowAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q, k, v):
        return slow_attention_n(q, k, v, softmax_n_param=0.)


@policy_registry.register(SlowAttention)
def slow_attention_converter(module: torch.nn.Module, module_index: int, softmax_n_param: float) -> torch.nn.Module:
    assert isinstance(module, SlowAttention)
    del module_index  # unused
    module.n = softmax_n_param
    setattr(module, 'forward', types.MethodType(forward, module))
    return module


def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # flash_attention_n only supports integer n, hence the int() cast
    return flash_attention_n(q, k, v, softmax_n_param=int(self.n))


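if __name__ == '__main__':
    model = MyModel()
    apply_attention_softmax_n(model=model, softmax_n_param=1.)  # will log a warning if the model isn't registered
    q = k = v = torch.randn(2, 1, 16, 8)  # (batch, heads, seq_len, head_dim); flash_attention_n needs a 4D query
    out = model(q, k, v)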
```
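We have not read the package's registry internals (the surgery is based on composer). To make the pattern above concrete, here is a hypothetical, torch-free mini registry; `Node` and `MiniPolicyRegistry` are illustrative names, not package APIs. Converters are looked up by a module's exact class, each match is replaced in its parent, and `module_index` counts matches per class; that last detail is our assumption about what the index means.

```python
from typing import Callable, Dict, List, Type


class Node:
    """Hypothetical stand-in for torch.nn.Module: a node with child nodes."""

    def __init__(self, *children: "Node") -> None:
        self.children: List["Node"] = list(children)


# Same shape as the converter above: (module, module_index, softmax_n_param) -> module.
Converter = Callable[[Node, int, float], Node]


class MiniPolicyRegistry:
    """Hypothetical registry mapping a module class to a converter function."""

    def __init__(self) -> None:
        self._policies: Dict[Type[Node], Converter] = {}

    def register(self, cls: Type[Node]) -> Callable[[Converter], Converter]:
        def decorator(fn: Converter) -> Converter:
            self._policies[cls] = fn
            return fn
        return decorator

    def apply(self, root: Node, softmax_n_param: float) -> int:
        """Replace every registered descendant of `root`; return how many were replaced."""
        counts: Dict[Type[Node], int] = {}

        def visit(node: Node) -> None:
            for i, child in enumerate(node.children):
                converter = self._policies.get(type(child))
                if converter is None:
                    visit(child)  # not registered: keep searching below it
                    continue
                index = counts.get(type(child), 0)  # 0, 1, 2, ... per registered class
                counts[type(child)] = index + 1
                node.children[i] = converter(child, index, softmax_n_param)

        visit(root)
        # A real implementation might log a warning when nothing was replaced.
        return sum(counts.values())
```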

            

Raw data

            {
    "_id": null,
    "home_page": "https://github.com/softmax1/Flash-Attention-Softmax-N",
    "name": "flash-attention-softmax-n",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.9",
    "maintainer_email": "",
    "keywords": "artificial intelligence,attention mechanism,transformers",
    "author": "Christopher W. Murphy",
    "author_email": "murphtron5000@gmail.com",
    "download_url": "https://files.pythonhosted.org/packages/c1/f2/4d5520611214fe18a8163f798532b5b7d43fca191b11a811f0441135cc83/flash-attention-softmax-n-0.3.2.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "GPLv3",
    "summary": "CUDA and Triton implementations of Flash Attention with SoftmaxN.",
    "version": "0.3.2",
    "project_urls": {
        "Homepage": "https://github.com/softmax1/Flash-Attention-Softmax-N"
    },
    "split_keywords": [
        "artificial intelligence",
        "attention mechanism",
        "transformers"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "cda6ff3a922ebc2ca1a51e7974112fb685411e18fed718f5eb428242f3048b26",
                "md5": "8637e06aea92f8e2c51b60fe9fdeaffd",
                "sha256": "f41d9dabe136d0c74a35ba247bb88b4f97cb281b4c5c5a1249af0cc790ae9596"
            },
            "downloads": -1,
            "filename": "flash_attention_softmax_n-0.3.2-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "8637e06aea92f8e2c51b60fe9fdeaffd",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.9",
            "size": 34627,
            "upload_time": "2023-11-21T14:15:27",
            "upload_time_iso_8601": "2023-11-21T14:15:27.993227Z",
            "url": "https://files.pythonhosted.org/packages/cd/a6/ff3a922ebc2ca1a51e7974112fb685411e18fed718f5eb428242f3048b26/flash_attention_softmax_n-0.3.2-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "c1f24d5520611214fe18a8163f798532b5b7d43fca191b11a811f0441135cc83",
                "md5": "69fabf8d40f247fc9183d65e24a36c89",
                "sha256": "81a488745c58aa4ab6915a49eb24ea5b3e4db72e9454db282f9b618f62bb1ae7"
            },
            "downloads": -1,
            "filename": "flash-attention-softmax-n-0.3.2.tar.gz",
            "has_sig": false,
            "md5_digest": "69fabf8d40f247fc9183d65e24a36c89",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.9",
            "size": 32525,
            "upload_time": "2023-11-21T14:15:29",
            "upload_time_iso_8601": "2023-11-21T14:15:29.597552Z",
            "url": "https://files.pythonhosted.org/packages/c1/f2/4d5520611214fe18a8163f798532b5b7d43fca191b11a811f0441135cc83/flash-attention-softmax-n-0.3.2.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2023-11-21 14:15:29",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "softmax1",
    "github_project": "Flash-Attention-Softmax-N",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "lcname": "flash-attention-softmax-n"
}
        