# dilated-attention-pytorch
(Unofficial) Implementation of `DilatedAttention` from *[LongNet: Scaling Transformers to 1,000,000,000 Tokens](https://arxiv.org/abs/2307.02486)* in PyTorch.
<img src="https://github.com/fkodom/dilated-attention-pytorch/assets/45951340/27304255-e51e-4298-9c7b-5b7e4a51e697" width=800 alt="long-net-sequence-length"/>
## Install
**NOTE**: This library depends on [facebookresearch/xformers](https://github.com/facebookresearch/xformers). If you're not using `torch>=2.0.0`, you may need to build `xformers` from source. See their [installation instructions](https://github.com/facebookresearch/xformers#installing-xformers).
PyPI:
```bash
pip install dilated-attention-pytorch
```
From source:
```bash
pip install "dilated-attention-pytorch @ git+ssh://git@github.com/fkodom/dilated-attention-pytorch.git"
```
For contributors:
```bash
# Install all dev dependencies (tests etc.)
pip install "dilated-attention-pytorch[all] @ git+ssh://git@github.com/fkodom/dilated-attention-pytorch.git"
# Setup pre-commit hooks
pre-commit install
```
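After installing, a quick sanity check (my own suggestion, not part of the package) is to import the main dependencies and confirm CUDA is visible:

```python
# Quick post-install sanity check: verifies the imports resolve and CUDA is visible.
import torch
import xformers
import dilated_attention_pytorch

print("torch", torch.__version__, "| cuda available:", torch.cuda.is_available())
print("xformers", xformers.__version__)
```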
## Benchmark
I follow the benchmarking procedure from the [LongNet paper](https://arxiv.org/abs/2307.02486) (Section 3.1) as closely as I can. They tested in a distributed, multi-GPU setting (and, by my estimation, with much better GPUs), while I test on a single RTX 2080 Ti, but the same general scaling trends still apply. Rather than 1B tokens, I scale the batch size so that the total number of tokens is 32M, which is the largest token count that fits in memory on my GPU when running dilated attention.
See: [benchmark.py](./benchmark.py)

> **NOTE**: Clearly, there are some inefficiencies in my `DilatedAttention` implementation for shorter sequence lengths. I'm not sure what's causing this. If you have any insights, please let me know!
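For intuition, here is a minimal sketch of that token-budget scaling (not a substitute for `benchmark.py`): the batch size shrinks as the sequence length grows, so the total token count stays fixed. The budget below is deliberately much smaller than 32M so it runs on modest GPUs.

```python
import time

import torch
from dilated_attention_pytorch.dilated_attention import DilatedAttention

TOTAL_TOKENS = 1 << 20  # ~1M-token budget (hypothetical; the real benchmark uses 32M)
NUM_HEADS, HEAD_DIM = 8, 64

attn = DilatedAttention(segment_lengths=[2048, 4096, 8192], dilation_rates=[1, 2, 4])

for seq_len in (8192, 16384, 32768):
    batch_size = TOTAL_TOKENS // seq_len
    q, k, v = (
        torch.randn(batch_size, seq_len, NUM_HEADS, HEAD_DIM, device="cuda", dtype=torch.float16)
        for _ in range(3)
    )
    with torch.no_grad():
        attn(q, k, v, is_causal=False)  # warm-up
        torch.cuda.synchronize()
        start = time.perf_counter()
        attn(q, k, v, is_causal=False)
        torch.cuda.synchronize()
    print(f"seq_len={seq_len:>6}  batch_size={batch_size:>4}  time={time.perf_counter() - start:.4f}s")
```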
## Usage
### `DilatedAttention`
The LongNet paper introduces a new attention mechanism called `DilatedAttention`. It is a drop-in replacement (see below) for "vanilla" attention that allows for much longer sequences to be processed.
> **NOTE**: `DilatedAttention` only supports `batch_first=True`. This is different from "vanilla" attention in PyTorch, which supports both `batch_first=True` and `batch_first=False`.
#### Arguments:
- `segment_lengths` (required, `list[int]`): Length of each attention segment. This is usually a geometric sequence increasing in powers of 2, such as `[2048, 4096, 8192]`.
- `dilation_rates` (required, `list[int]`): Dilation rate for each segment. Like with `segment_lengths`, this is usually a geometric sequence increasing in powers of 2, such as `[1, 2, 4]`.
```python
import torch
from dilated_attention_pytorch.dilated_attention import DilatedAttention
dilated_attention = DilatedAttention(
    segment_lengths=[2048, 4096, 8192],
    dilation_rates=[1, 2, 4],
)

# shape: (batch_size, seq_len, num_heads, head_dim)
# NOTE: 'seq_len' must be a multiple of 8192 (the largest segment length)
# NOTE: For best performance, use 'dtype=torch.float16' or 'dtype=torch.bfloat16'
query = torch.randn(1, 8192, 8, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 8192, 8, 64, device="cuda", dtype=torch.float16)
value = torch.randn(1, 8192, 8, 64, device="cuda", dtype=torch.float16)

out = dilated_attention(query, key, value, is_causal=False)  # default: is_causal=False
print(out.shape)
# torch.Size([1, 8192, 8, 64])
```
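For intuition about what the two argument lists mean, here is a rough conceptual sketch (my own illustration, not the library's implementation, which also offsets the dilation per attention head and re-weights the combined outputs): each `(segment_length, dilation_rate)` pair splits the sequence into segments and keeps every `dilation_rate`-th position within each segment, so standard attention only ever runs over these much shorter rows.

```python
import torch

# Hypothetical single (segment_length, dilation_rate) pair taken from the lists above.
batch, seq_len, heads, head_dim = 1, 8192, 8, 64
segment_length, dilation_rate = 2048, 2

x = torch.randn(batch, seq_len, heads, head_dim)

# 1) Split the sequence into non-overlapping segments of length `segment_length`.
segments = x.view(batch, seq_len // segment_length, segment_length, heads, head_dim)

# 2) Within each segment, keep every `dilation_rate`-th position (offset 0 here).
sparse = segments[:, :, ::dilation_rate]
print(sparse.shape)  # torch.Size([1, 4, 1024, 8, 64]) -- attention runs over length-1024 rows
```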
### `MultiheadDilatedAttention`
`MultiheadDilatedAttention` is a drop-in replacement (see below) for `nn.MultiheadAttention` that uses `DilatedAttention` instead of "vanilla" attention. It also incorporates improvements from the [MAGNETO architecture](https://arxiv.org/abs/2210.06423) (`nn.LayerNorm` placements), as mentioned in the [LongNet paper](https://arxiv.org/abs/2307.02486).
> **NOTE**: `MultiheadDilatedAttention` only supports `batch_first=True`. This is different from `nn.MultiheadAttention`, which supports both `batch_first=True` and `batch_first=False`.
#### Arguments:
- `segment_lengths` (required, `list[int]`): Length of each attention segment. This is usually a geometric sequence increasing in powers of 2, such as `[2048, 4096, 8192]`.
- `dilation_rates` (required, `list[int]`): Dilation rate for each segment. Like with `segment_lengths`, this is usually a geometric sequence increasing in powers of 2, such as `[1, 2, 4]`.
- Many of the same arguments as `nn.MultiheadAttention`. See the `MultiheadDilatedAttention` class for more details.
```python
import torch
from dilated_attention_pytorch.dilated_attention import MultiheadDilatedAttention

device = torch.device("cuda")
dtype = torch.float16
embed_dim = 512
# NOTE: Omitting most of the optional arguments for brevity
mhda = MultiheadDilatedAttention(
    embed_dim=embed_dim,
    num_heads=8,
    segment_lengths=[2048, 4096, 8192],
    dilation_rates=[1, 2, 4],
    device=device,  # optional
    dtype=dtype,  # optional
)
# shape: (batch_size, seq_len, embed_dim)
# NOTE: 'seq_len' must be a multiple of 8192 (the largest segment length)
x = torch.randn(1, 8192, embed_dim, device=device, dtype=dtype)
y = mhda(x, x, x, is_causal=False) # default: is_causal=False
print(y.shape)
# torch.Size([1, 8192, 512])
```
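As a sketch of how `MultiheadDilatedAttention` might slot into a larger model, here is a minimal pre-norm self-attention block. The residual/LayerNorm wiring below is my own illustration (it is not the MAGNETO layer shipped with this package), and it assumes `mhda(...)` returns the attention output tensor, as in the example above.

```python
import torch
from torch import nn
from dilated_attention_pytorch.dilated_attention import MultiheadDilatedAttention


class DilatedSelfAttentionBlock(nn.Module):
    """Hypothetical pre-norm block: dilated self-attention plus a residual connection."""

    def __init__(self, embed_dim, num_heads, segment_lengths, dilation_rates, device=None, dtype=None):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim, device=device, dtype=dtype)
        self.attn = MultiheadDilatedAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            segment_lengths=segment_lengths,
            dilation_rates=dilation_rates,
            device=device,
            dtype=dtype,
        )

    def forward(self, x, is_causal=False):
        h = self.norm(x)  # pre-norm
        return x + self.attn(h, h, h, is_causal=is_causal)  # residual connection


block = DilatedSelfAttentionBlock(
    512, 8, [2048, 4096, 8192], [1, 2, 4], device=torch.device("cuda"), dtype=torch.float16
)
x = torch.randn(1, 8192, 512, device="cuda", dtype=torch.float16)
print(block(x, is_causal=False).shape)  # torch.Size([1, 8192, 512])
```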
### `LongNet`
The [LongNet paper](https://arxiv.org/abs/2307.02486) culminates in a transformer architecture, which can be trained for language modeling with very long context windows. I have implemented two `LongNet` variants, based on the **base** configurations from the paper:
- `LongNetLM` - designed specifically for language modeling
- `LongNet` - a more general encoder-decoder architecture, which is not specific to language modeling
Based on these implementations, it is fairly straightforward to adapt `LongNet` to encoder- or decoder-only architectures, as needed for specific applications.
```python
import torch
from dilated_attention_pytorch.long_net import LongNetLM, LongNet

device = torch.device("cuda")
dtype = torch.float16
# NOTE: Showing all default values, which are described in the paper.
net = LongNet(
    d_model=768,
    nhead=12,
    num_encoder_layers=12,
    num_decoder_layers=12,
    dim_feedforward=3072,
    segment_lengths=[2048, 4096, 8192, 16384, 32768],
    dilation_rates=[1, 2, 4, 6, 12],
    dropout=0.0,
    activation="relu",
    layer_norm_eps=1e-5,
    device=device,
    dtype=dtype,
)
# shape: (batch_size, seq_len, d_model)
x = torch.randn(1, 32768, 768, device=device, dtype=dtype)
with torch.no_grad():
    y = net(x, is_causal=True)  # default: is_causal=True
print(y.shape)
# torch.Size([1, 32768, 768])
num_tokens = 10000 # (required) usually obtained from the tokenizer
lm = LongNetLM(
    num_tokens=num_tokens,
    d_model=768,
    nhead=12,
    num_encoder_layers=12,
    num_decoder_layers=12,
    dim_feedforward=3072,
    segment_lengths=[2048, 4096, 8192, 16384, 32768],
    dilation_rates=[1, 2, 4, 6, 12],
    dropout=0.0,
    activation="relu",
    layer_norm_eps=1e-5,
    device=device,
    dtype=dtype,
)
# shape: (batch_size, seq_len)
x = torch.randint(0, num_tokens, (1, 32768), device=device, dtype=torch.long)
with torch.no_grad():
    y = lm(x, is_causal=True)  # default: is_causal=True
print(y.shape)
# torch.Size([1, 32768, 10000])  # i.e. (batch_size, seq_len, num_tokens)
```
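For completeness, here is a minimal sketch of a next-token-prediction training step for `LongNetLM`, reusing `lm`, `num_tokens`, and `device` from the block above. It assumes the model output is a `(batch_size, seq_len, num_tokens)` tensor of unnormalized logits (as the shape above suggests); a real training loop would typically use mixed-precision autocast with a gradient scaler rather than a pure-fp16 model.

```python
import torch
import torch.nn.functional as F

optimizer = torch.optim.AdamW(lm.parameters(), lr=1e-4)

tokens = torch.randint(0, num_tokens, (1, 32768), device=device)  # stand-in for real data
logits = lm(tokens, is_causal=True)  # (batch_size, seq_len, num_tokens)

# Next-token prediction: position t predicts token t+1.
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, num_tokens).float(),  # cast up for a numerically stable loss
    tokens[:, 1:].reshape(-1),
)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```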