[![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf)
# Multi-Modal Causal Attention
The open source community's implementation of Multi-Modal Causal Attention (MMCA) from the paper "DeepSpeed-VisualChat: Multi-Round Multi-Image Interleave Chat via Multi-Modal Causal Attention".
[Paper Link](https://arxiv.org/pdf/2309.14327.pdf)
# Appreciation
* Lucidrains
* Agorians
# Install
`pip install mmca`
# Usage
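```python
import torch
from mmca.main import MultiModalCausalAttention

# Embedding dimension 512 and 8 attention heads
attn = MultiModalCausalAttention(dim=512, heads=8)

# Two input sequences shaped (batch, seq_len, dim); their lengths differ here
x = torch.randn(1, 10, 512)
y = torch.randn(1, 20, 512)

# The module returns updated versions of both sequences
x, y = attn(x, y)

print(x)
print(y)
```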
# Architecture
Algorithmic pseudocode
```
Input:  visual tokens V, textual tokens T
Output: updated visual tokens V', updated textual tokens T'

procedure MMCA(V, T)
    for each visual token v in V do
        v' = self_attention(v)    // visual tokens attend only to themselves
    end for
    for each textual token t in T do
        // textual tokens attend to all previous textual tokens (including t) AND to the image tokens,
        // with separate attention weights for each
        t' = attention(t, T_previous) + attention(t, V)
    end for
    return V', T'
end procedure
```
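The pseudocode can be turned into a small runnable sketch. This is a hypothetical pure-Python illustration, not the API of the `mmca` package: it assumes a single attention head, no learned projections (each token vector serves as its own query, key, and value), scaling by 1/sqrt(d), and that all visual tokens precede the text.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Scaled dot-product attention of one query over a list of keys/values."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    return [sum(w * val[i] for w, val in zip(weights, values)) for i in range(d)]


def mmca(visual, text):
    """Hypothetical single-head MMCA with identity Q/K/V projections (an assumption)."""
    # First loop: a visual token attends only to itself, so its single
    # attention weight is 1 and its output equals its own value vector.
    visual_out = [attend(v, [v], [v]) for v in visual]

    # Second loop: a text token attends causally to earlier text (and itself)
    # and, through a separately normalised softmax, to every visual token.
    text_out = []
    for i, t in enumerate(text):
        from_text = attend(t, text[: i + 1], text[: i + 1])
        from_visual = attend(t, visual, visual)
        text_out.append([a + b for a, b in zip(from_text, from_visual)])
    return visual_out, text_out


visual = [[1.0, 0.0], [0.0, 1.0]]
text = [[1.0, 1.0], [0.5, -0.5]]
v_out, t_out = mmca(visual, text)
print(t_out[0])  # [1.5, 1.5]
```

Because the two calls to `attend` normalise separately, a text token's weights over earlier text sum to 1 and its weights over the visual tokens also sum to 1; the two contributions are then added, as in the pseudocode.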
# Multi-Modal Causal Attention: A Study
MMCA is a novel attention mechanism designed to handle multi-modal data, i.e., data that comes from different sources or formats, such as text and images. It is an extension of the causal attention mechanism, which is commonly used in transformer models for tasks like language modeling.
## Causal Attention
Before diving into MMCA, let's first look at causal attention. In a transformer, attention assigns each output position a set of weights over the input positions, determining how much each input contributes to that output.
Causal attention, also known as masked or autoregressive self-attention, is a form of self-attention in which a token can attend only to itself and to earlier tokens in the sequence. In contrast, bidirectional (unmasked) self-attention lets every token attend to every other token in the sequence.
The causal attention mechanism can be visualized as follows:
```
Token1 -> |------|
Token2 -> |------|------|
Token3 -> |------|------|------|
Token4 -> |------|------|------|------|
```
Each token can attend to itself and all the tokens before it, but not the ones after it.
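A causal mask can be sketched in a few lines of plain Python. The helper names are illustrative only: `causal_mask` builds the triangular pattern drawn above, and `masked_softmax` shows that a token assigns zero weight to positions after it, however large their scores are.

```python
import math


def causal_mask(n):
    """Return an n x n mask where mask[i][j] is True if token i may attend to token j."""
    return [[j <= i for j in range(n)] for i in range(n)]


def masked_softmax(scores, allowed):
    """Softmax over one row of scores; positions that are not allowed get weight 0."""
    m = max(s for s, ok in zip(scores, allowed) if ok)  # stabilise using visible scores only
    exps = [math.exp(s - m) if ok else 0.0 for s, ok in zip(scores, allowed)]
    total = sum(exps)
    return [e / total for e in exps]


mask = causal_mask(4)
for row in mask:
    print("".join("x" if ok else "." for ok in row))
# x...
# xx..
# xxx.
# xxxx

# Token2 ignores Token3 and Token4 even though their scores are much higher.
weights = masked_softmax([0.0, 0.0, 5.0, 9.0], mask[1])
print(weights)  # [0.5, 0.5, 0.0, 0.0]
```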
----
## Multi-Modal Causal Attention
In a multi-modal setting, we often deal with different types of data simultaneously. For instance, in an image captioning task, the model has to process both image features and textual data. This is where MMCA comes into play.
MMCA extends the concept of causal attention to handle multi-modal data. The key idea behind MMCA is as follows:
1. Visual tokens attend only to themselves, because they have already been encoded by the visual encoder.
2. Textual tokens attend to all previous tokens, both textual and visual, but use two separate attention weight matrices: one for previous textual tokens and one for image tokens.
This can be visualized as follows:
```
Visual Tokens:
V1 -> |------|
V2 -> |------|
V3 -> |------|
Textual Tokens:
T1 -> |------|------|------|------|
T2 -> |------|------|------|------|------|
T3 -> |------|------|------|------|------|------|
```
Here, `V1`, `V2`, and `V3` are visual tokens and `T1`, `T2`, and `T3` are textual tokens; each segment is one position the token attends to. Each visual token attends only to itself, while each textual token attends to all visual tokens and to every textual token up to and including itself.
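The same pattern can be written as a boolean mask over the concatenated sequence `[V1, V2, V3, T1, T2, T3]`. The sketch below is hypothetical and assumes, as the diagram does, that all visual tokens come before the text. A single mask only records which positions are visible; MMCA's separate attention weights would additionally split each text row into its visual part and its textual part and normalise each part on its own.

```python
def mmca_mask(num_visual, num_text):
    """Attention mask over [visual tokens..., text tokens...].

    mask[i][j] is True when token i may attend to token j.
    Assumes all visual tokens precede the text, as in the diagram.
    """
    n = num_visual + num_text
    mask = [[False] * n for _ in range(n)]
    for i in range(num_visual):
        mask[i][i] = True  # a visual token sees only itself
    for t in range(num_text):
        i = num_visual + t
        for j in range(num_visual):
            mask[i][j] = True  # every visual token
        for j in range(num_visual, i + 1):
            mask[i][j] = True  # earlier text tokens and itself
    return mask


labels = ["V1", "V2", "V3", "T1", "T2", "T3"]
for label, row in zip(labels, mmca_mask(3, 3)):
    print(label, "".join("x" if ok else "." for ok in row))
# V1 x.....
# V2 .x....
# V3 ..x...
# T1 xxxx..
# T2 xxxxx.
# T3 xxxxxx
```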
----
## Mathematical Formulation
Let's now look at the mathematical formulation of MMCA. In transformers, attention weights are typically computed from the dot products of the query `Q` and key `K` matrices, followed by a softmax. In MMCA, textual tokens get two separate attention weight matrices: one over previous textual tokens and one over visual tokens.
Let `Q_T` and `K_T` be the query and key matrices for textual tokens, and `Q_V` and `K_V` be the query and key matrices for visual tokens. The attention weights for textual tokens attending to previous textual tokens (`A_TT`) and visual tokens (`A_TV`) can be computed as follows:
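```
A_TT = softmax(Q_T * K_T^T / sqrt(d) + M)
A_TV = softmax(Q_T * K_V^T / sqrt(d))
```
Here, `d` is the key dimension (the `1/sqrt(d)` factor is the usual scaled dot-product normalization) and `M` is a causal mask that is 0 where a textual token may attend to an earlier or identical textual position and `-inf` otherwise, so each textual token only sees previous textual tokens. The two softmaxes are taken separately, so each row of `A_TT` and each row of `A_TV` sums to 1 on its own.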
The updated textual token representations are then obtained by applying these attention weights to the corresponding value matrices and summing the results:
```
T' = A_TT * V_T + A_TV * V_V
```
Here, `V_T` and `V_V` are the value matrices for textual and visual tokens, respectively.
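The formulas can be checked numerically with a small pure-Python sketch. The helper names are hypothetical, and the sketch assumes single-head attention, scaling by 1/sqrt(d), a causal mask on `A_TT` (text attends only to earlier text and itself), and no mask on `A_TV` (all image tokens precede the text).

```python
import math


def matmul_t(a, b):
    """Return a @ b^T for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in b] for row in a]


def row_softmax(scores, causal=False):
    """Row-wise softmax; with causal=True, row i only keeps columns j <= i."""
    out = []
    for i, row in enumerate(scores):
        allowed = [(not causal) or j <= i for j in range(len(row))]
        m = max(s for s, ok in zip(row, allowed) if ok)
        exps = [math.exp(s - m) if ok else 0.0 for s, ok in zip(row, allowed)]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def weighted_sum(weights, values):
    """Return weights @ values."""
    dim = len(values[0])
    return [[sum(w * val[k] for w, val in zip(row, values)) for k in range(dim)] for row in weights]


def mmca_text_update(Q_T, K_T, V_T, K_V, V_V):
    """Compute A_TT, A_TV and T' = A_TT * V_T + A_TV * V_V for one head (a sketch)."""
    scale = 1.0 / math.sqrt(len(Q_T[0]))

    def scaled(s):
        return [[x * scale for x in row] for row in s]

    A_TT = row_softmax(scaled(matmul_t(Q_T, K_T)), causal=True)  # separate softmax over text
    A_TV = row_softmax(scaled(matmul_t(Q_T, K_V)))  # separate softmax over image tokens
    text_part = weighted_sum(A_TT, V_T)
    visual_part = weighted_sum(A_TV, V_V)
    T_new = [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(text_part, visual_part)]
    return A_TT, A_TV, T_new


# Two text tokens and three visual tokens, all with dimension 2 (toy values).
Q_T = K_T = V_T = [[1.0, 0.0], [0.0, 1.0]]
K_V = V_V = [[1.0, 1.0], [2.0, 0.0], [0.0, 2.0]]
A_TT, A_TV, T_new = mmca_text_update(Q_T, K_T, V_T, K_V, V_V)
print(A_TT[0])  # [1.0, 0.0]: the first text token sees only itself among the text
```

Because `A_TT` and `A_TV` are normalised independently, the total weight a text token distributes is 2 (1 over text, 1 over images), rather than 1 as in a single softmax over the concatenated sequence.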
## Conclusion
Multi-Modal Causal Attention extends causal attention to multi-modal inputs. Visual tokens, which have already been encoded by the visual encoder, attend only to themselves, so no attention is computed from image tokens to the rest of the sequence. Textual tokens attend causally to earlier text and to the image tokens through two separately normalized attention weight matrices, so the weight a textual token places on the image does not compete with the weight it places on earlier text. The DeepSpeed-VisualChat paper introduces this design for multi-round, multi-image interleaved chat.
---
# Todo
* Implement flash attention from zeta as the main attention mechanism
---
# License
MIT
---
# Citations
```bibtex
@misc{2309.14327,
Author = {Zhewei Yao and Xiaoxia Wu and Conglong Li and Minjia Zhang and Heyang Qi and Olatunji Ruwase and Ammar Ahmad Awan and Samyam Rajbhandari and Yuxiong He},
Title = {DeepSpeed-VisualChat: Multi-Round Multi-Image Interleave Chat via Multi-Modal Causal Attention},
Year = {2023},
Eprint = {arXiv:2309.14327},
}
```
Raw data
{
"_id": null,
"home_page": "https://github.com/kyegomez/MMCA",
"name": "mmca",
"maintainer": "",
"docs_url": null,
"requires_python": ">=3.6,<4.0",
"maintainer_email": "",
"keywords": "artificial intelligence,deep learning,optimizers,Prompt Engineering",
"author": "Kye Gomez",
"author_email": "kye@apac.ai",
"download_url": "https://files.pythonhosted.org/packages/85/7d/e590e3b68976c27102da784e9d5dc70c1ff6d28e7ba1f559ddd8d2b96984/mmca-0.0.4.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": "MIT",
"summary": "MMCA - Pytorch",
"version": "0.0.4",
"project_urls": {
"Homepage": "https://github.com/kyegomez/MMCA",
"Repository": "https://github.com/kyegomez/MMCA"
},
"split_keywords": [
"artificial intelligence",
"deep learning",
"optimizers",
"prompt engineering"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "d11c086d27b1b494f6f410f738ed245e57e181b1ddb920ad78cbcfcf1cbb9fda",
"md5": "57d31cac8e4ea255324de450f90ed77b",
"sha256": "962bc0681b506e0176a114d481a01b4b49c9ac475e972d1024a72cc6877d27f4"
},
"downloads": -1,
"filename": "mmca-0.0.4-py3-none-any.whl",
"has_sig": false,
"md5_digest": "57d31cac8e4ea255324de450f90ed77b",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.6,<4.0",
"size": 5135,
"upload_time": "2023-09-27T01:52:17",
"upload_time_iso_8601": "2023-09-27T01:52:17.273814Z",
"url": "https://files.pythonhosted.org/packages/d1/1c/086d27b1b494f6f410f738ed245e57e181b1ddb920ad78cbcfcf1cbb9fda/mmca-0.0.4-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "857de590e3b68976c27102da784e9d5dc70c1ff6d28e7ba1f559ddd8d2b96984",
"md5": "1490c6376010965f5e7bcd4e8f006ac3",
"sha256": "0b21d1d1b5d81aa781fc86df973dc2c2aca2ef6c15d633be547017d64a60426b"
},
"downloads": -1,
"filename": "mmca-0.0.4.tar.gz",
"has_sig": false,
"md5_digest": "1490c6376010965f5e7bcd4e8f006ac3",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.6,<4.0",
"size": 5571,
"upload_time": "2023-09-27T01:52:18",
"upload_time_iso_8601": "2023-09-27T01:52:18.829826Z",
"url": "https://files.pythonhosted.org/packages/85/7d/e590e3b68976c27102da784e9d5dc70c1ff6d28e7ba1f559ddd8d2b96984/mmca-0.0.4.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2023-09-27 01:52:18",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "kyegomez",
"github_project": "MMCA",
"travis_ci": false,
"coveralls": false,
"github_actions": true,
"requirements": [],
"lcname": "mmca"
}