dilated-attention-pytorch


- Name: dilated-attention-pytorch
- Version: 0.2.0
- Home page: https://github.com/fkodom/dilated-attention-pytorch
- Author: Frank Odom
- Requires Python: >=3.8
- Upload time: 2023-08-03 00:51:29
- Requirements: none recorded
# dilated-attention-pytorch

(Unofficial) Implementation of `DilatedAttention` from *[LongNet: Scaling Transformers to 1,000,000,000 Tokens](https://arxiv.org/abs/2307.02486)* in PyTorch.

<img src="https://github.com/fkodom/dilated-attention-pytorch/assets/45951340/27304255-e51e-4298-9c7b-5b7e4a51e697" width=800 alt="long-net-sequence-length"/>

## Install

**NOTE**: This library depends on [facebookresearch/xformers](https://github.com/facebookresearch/xformers).  If you're not using `torch>=2.0.0`, you may need to install it from source.  See their [installation instructions](https://github.com/facebookresearch/xformers#installing-xformers).

PyPI:

```bash
pip install dilated-attention-pytorch
```

From source:
```bash
pip install "dilated-attention-pytorch @ git+ssh://git@github.com/fkodom/dilated-attention-pytorch.git"
```

For contributors:
```bash
# Install all dev dependencies (tests etc.)
pip install "dilated-attention-pytorch[all] @ git+ssh://git@github.com/fkodom/dilated-attention-pytorch.git"
# Set up pre-commit hooks
pre-commit install
```


## Benchmark

I follow the benchmarking procedure from the [LongNet paper](https://arxiv.org/abs/2307.02486) (Section 3.1) as closely as I can.  The authors tested in a distributed, multi-GPU setting (and, by my estimation, with much better GPUs), whereas I test on a single RTX 2080 Ti; the same general scaling trends still apply.  Rather than 1B tokens, I scale the batch size so that the total number of tokens is 32M, which is the longest sequence that fits in GPU memory on my machine when running dilated attention.
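
As a rough sketch of that scaling (not taken from `benchmark.py`), the batch size for each sequence length is the token budget divided by the sequence length. This assumes "32M" means 2**25 tokens and that the tested sequence lengths are powers of 2; the helper name `batch_size_for` and the range of lengths printed are illustrative only.

```python
# Assumption: "32M" tokens is taken to mean 2**25 = 33,554,432 tokens.
TOTAL_TOKENS = 2**25


def batch_size_for(seq_len: int, total_tokens: int = TOTAL_TOKENS) -> int:
    """Hypothetical helper: batch size that keeps batch_size * seq_len at the budget."""
    if total_tokens % seq_len != 0:
        raise ValueError(f"seq_len={seq_len} does not evenly divide {total_tokens}")
    return total_tokens // seq_len


# Illustrative range: sequence lengths from 2**13 up to the full budget (2**25).
for exponent in range(13, 26):
    seq_len = 2**exponent
    print(f"seq_len={seq_len:>10,}  batch_size={batch_size_for(seq_len):>5}")
```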

See: [benchmark.py](./benchmark.py)

![benchmark](./doc/benchmark.png)

> **NOTE**: Clearly, there are some inefficiencies in my `DilatedAttention` implementation for shorter sequence lengths.  I'm not sure what's causing this.  If you have any insights, please let me know!


## Usage

### `DilatedAttention`

The LongNet paper introduces a new attention mechanism called `DilatedAttention`.  It is a drop-in replacement for "vanilla" attention (with the caveat below) that can process much longer sequences.
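
As a back-of-the-envelope illustration of why longer sequences become feasible, the sketch below counts the query-key pairs scored per head, following the segment-and-dilate scheme described in the LongNet paper: a branch with segment length `w` and dilation rate `r` splits the sequence into `seq_len // w` segments and keeps every `r`-th token in each, so it scores about `(seq_len // w) * ceil(w / r)**2` pairs. The helpers `dilated_pairs` and `dense_pairs` are hypothetical, not part of this library, and the count ignores causal masking, per-head offsets, and the cost of combining branch outputs.

```python
import math
from typing import List


def dilated_pairs(seq_len: int, segment_lengths: List[int], dilation_rates: List[int]) -> int:
    """Hypothetical helper: query-key pairs scored per head across all branches.

    Simplified model (assumption): each branch has seq_len // w segments, each
    keeping ceil(w / r) tokens that attend densely to one another.
    """
    total = 0
    for w, r in zip(segment_lengths, dilation_rates):
        num_segments = seq_len // w
        tokens_per_segment = math.ceil(w / r)
        total += num_segments * tokens_per_segment**2
    return total


def dense_pairs(seq_len: int) -> int:
    """Vanilla attention scores every query against every key."""
    return seq_len**2


segment_lengths = [2048, 4096, 8192]
dilation_rates = [1, 2, 4]
for seq_len in (8192, 65536, 1048576):
    ratio = dense_pairs(seq_len) / dilated_pairs(seq_len, segment_lengths, dilation_rates)
    print(f"seq_len={seq_len:>9,}  dense / dilated = {ratio:.1f}x")
```

Because every branch in this configuration attends over 2048 tokens per segment, the dilated count grows linearly with `seq_len`, while the dense count grows quadratically.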

> **NOTE**: `DilatedAttention` only supports `batch_first=True`.  This is different from "vanilla" attention in PyTorch, which supports both `batch_first=True` and `batch_first=False`. 

#### Arguments:
- `segment_lengths` (required, `list[int]`): Length of each attention segment.  This is usually a geometric sequence increasing in powers of 2, such as `[2048, 4096, 8192]`.
- `dilation_rates` (required, `list[int]`): Dilation rate for each segment, one per entry in `segment_lengths`.  As with `segment_lengths`, this is usually a geometric sequence increasing in powers of 2, such as `[1, 2, 4]`.
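
To illustrate how one (`segment_length`, `dilation_rate`) pair selects tokens, here is a minimal pure-Python sketch of the sparsification pattern described in the LongNet paper, using toy sizes. It is not this library's implementation; `dilated_segment_indices` and its `offset` parameter are hypothetical names for illustration.

```python
from typing import List


def dilated_segment_indices(
    seq_len: int, segment_length: int, dilation_rate: int, offset: int = 0
) -> List[List[int]]:
    """Hypothetical helper: token positions grouped for one dilated-attention branch.

    The sequence is cut into non-overlapping segments of `segment_length` tokens,
    and within each segment only every `dilation_rate`-th position (starting at
    `offset`) is kept. Attention would then be computed separately inside each group.
    """
    if seq_len % segment_length != 0:
        raise ValueError("seq_len must be a multiple of segment_length")
    if not 0 <= offset < dilation_rate:
        raise ValueError("offset must be in [0, dilation_rate)")
    return [
        list(range(start + offset, start + segment_length, dilation_rate))
        for start in range(0, seq_len, segment_length)
    ]


# Toy sizes: a 16-token sequence with segment_lengths=[4, 8, 16], dilation_rates=[1, 2, 4].
for w, r in zip([4, 8, 16], [1, 2, 4]):
    print(f"w={w:>2}, r={r}: {dilated_segment_indices(16, w, r)}")
```

Each group above holds 4 tokens, so every branch does the same amount of work per group. With `offset=0`, some positions (such as 1 and 3) appear only in the undilated branch; the paper addresses this by shifting the offset across attention heads. Whether and how this library applies such offsets is not shown here.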


```python
import torch
from dilated_attention_pytorch.dilated_attention import DilatedAttention

dilated_attention = DilatedAttention(
    segment_lengths=[2048, 4096, 8192],
    dilation_rates=[1, 2, 4],
)

# shape: (batch_size, seq_len, num_heads, head_dim)
# NOTE: 'seq_len' must be a multiple of 8192 (the largest segment length)
# NOTE: For best performance, use 'dtype=torch.float16' or 'dtype=torch.bfloat16'
query = torch.randn(1, 8192, 8, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 8192, 8, 64, device="cuda", dtype=torch.float16)
value = torch.randn(1, 8192, 8, 64, device="cuda", dtype=torch.float16)

out = dilated_attention(query, key, value, is_causal=False)  # default: is_causal=False
print(out.shape)
# torch.Size([1, 8192, 8, 64])
```
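
The example above requires `seq_len` to be a multiple of the largest segment length. If your inputs do not satisfy this, one option is to pad them up to the next valid length and discard the outputs at padded positions afterwards. The hypothetical helper below, which is not part of this library, computes that target length; using the least common multiple of all segment lengths also covers configurations that are not powers of 2. Note that with `is_causal=False`, real tokens can still attend to trailing padding, so this approach is safest with `is_causal=True`.

```python
from functools import reduce
from math import gcd
from typing import List


def _lcm(a: int, b: int) -> int:
    """Least common multiple (math.lcm needs Python 3.9+, this package allows 3.8)."""
    return a * b // gcd(a, b)


def padded_seq_len(seq_len: int, segment_lengths: List[int]) -> int:
    """Hypothetical helper: smallest length >= seq_len that every segment length divides."""
    if seq_len < 1 or not segment_lengths:
        raise ValueError("seq_len must be positive and segment_lengths non-empty")
    block = reduce(_lcm, segment_lengths)
    return -(-seq_len // block) * block  # ceiling division, then scale back up


print(padded_seq_len(8192, [2048, 4096, 8192]))
print(padded_seq_len(10000, [2048, 4096, 8192]))
```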


### `MultiheadDilatedAttention`

`MultiheadDilatedAttention` is a drop-in replacement (see below) for `nn.MultiheadAttention` that uses `DilatedAttention` instead of "vanilla" attention.  It also incorporates improvements from the [MAGNETO architecture](https://arxiv.org/abs/2210.06423) (`nn.LayerNorm` placements), as mentioned in the [LongNet paper](https://arxiv.org/abs/2307.02486).

> **NOTE**: `MultiheadDilatedAttention` only supports `batch_first=True`.  This is different from `nn.MultiheadAttention`, which supports both `batch_first=True` and `batch_first=False`.

#### Arguments:
- `segment_lengths` (required, `list[int]`): Length of each attention segment.  This is usually a geometric sequence increasing in powers of 2, such as `[2048, 4096, 8192]`.
- `dilation_rates` (required, `list[int]`): Dilation rate for each segment, one per entry in `segment_lengths`.  As with `segment_lengths`, this is usually a geometric sequence increasing in powers of 2, such as `[1, 2, 4]`.
- Many of the same arguments from `nn.MultiheadAttention`.  See the `MultiheadDilatedAttention` class for more details.
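
`MultiheadDilatedAttention` works on `(batch_size, seq_len, embed_dim)` inputs, while `DilatedAttention` takes per-head tensors of shape `(batch_size, seq_len, num_heads, head_dim)`. Assuming the usual multi-head convention that `embed_dim` is split evenly across heads (as `nn.MultiheadAttention` requires), the shapes relate as in this small sketch; for example, `embed_dim=512` with `num_heads=8` gives 64-dimensional heads. The helper `per_head_shape` is hypothetical.

```python
from typing import Tuple


def per_head_shape(
    batch_size: int, seq_len: int, embed_dim: int, num_heads: int
) -> Tuple[int, int, int, int]:
    """Hypothetical helper: per-head tensor shape after splitting embed_dim across heads."""
    if embed_dim % num_heads != 0:
        raise ValueError(f"embed_dim={embed_dim} is not divisible by num_heads={num_heads}")
    head_dim = embed_dim // num_heads
    return (batch_size, seq_len, num_heads, head_dim)


print(per_head_shape(1, 8192, 512, 8))
```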

```python
import torch
from dilated_attention_pytorch.dilated_attention import MultiheadDilatedAttention

device = torch.device("cuda")
dtype = torch.float16
embed_dim = 512

# NOTE: Omitting most of the optional arguments for brevity
mhda = MultiheadDilatedAttention(
    embed_dim=embed_dim,
    num_heads=8,
    segment_lengths=[2048, 4096, 8192],
    dilation_rates=[1, 2, 4],
    device=device,  # optional
    dtype=dtype,  # optional
)

# shape: (batch_size, seq_len, embed_dim)
# NOTE: 'seq_len' must be a multiple of 8192 (the largest segment length)
x = torch.randn(1, 8192, embed_dim, device=device, dtype=dtype)
y = mhda(x, x, x, is_causal=False)  # default: is_causal=False
print(y.shape)
# torch.Size([1, 8192, 512])
```


### `LongNet`

The [LongNet paper](https://arxiv.org/abs/2307.02486) culminates in a transformer architecture, which can be trained for language modeling with very long context windows.  I have implemented two `LongNet` variants, based on the **base** configurations from the paper:
- `LongNetLM` - designed specifically for language modeling
- `LongNet` - a more general encoder-decoder architecture, which is not specific to language modeling

Based on these implementations, it is fairly straightforward to adapt `LongNet` to encoder- or decoder-only architectures, as needed for specific applications.

```python
import torch
from dilated_attention_pytorch.long_net import LongNetLM, LongNet

device = torch.device("cuda")
dtype = torch.float16

# NOTE: Showing all default values, which are described in the paper.
net = LongNet(
    d_model=768,
    nhead=12,
    num_encoder_layers=12,
    num_decoder_layers=12,
    dim_feedforward=3072,
    segment_lengths=[2048, 4096, 8192, 16384, 32768],
    dilation_rates=[1, 2, 4, 6, 12],
    dropout=0.0,
    activation="relu",
    layer_norm_eps=1e-5,
    device=device,
    dtype=dtype,
)
# shape: (batch_size, seq_len, d_model)
x = torch.randn(1, 32768, 768, device=device, dtype=dtype)
with torch.no_grad():
    y = net(x, is_causal=True)  # default: is_causal=True
print(y.shape)
# torch.Size([1, 32768, 768])

num_tokens = 10000  # (required) usually obtained from the tokenizer
lm = LongNetLM(
    num_tokens=num_tokens,
    d_model=768,
    nhead=12,
    num_encoder_layers=12,
    num_decoder_layers=12,
    dim_feedforward=3072,
    segment_lengths=[2048, 4096, 8192, 16384, 32768],
    dilation_rates=[1, 2, 4, 6, 12],
    dropout=0.0,
    activation="relu",
    layer_norm_eps=1e-5,
    device=device,
    dtype=dtype,
)
# shape: (batch_size, seq_len)
x = torch.randint(0, num_tokens, (1, 32768), device=device, dtype=torch.long)
with torch.no_grad():
    y = lm(x, is_causal=True)  # default: is_causal=True
print(y.shape)
# torch.Size([1, 32768, 10000])
```

            
