[![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf)
# MGQA
The open-source implementation of multi-grouped query attention (MGQA) from the paper "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints".
[Paper Link](https://arxiv.org/abs/2305.13245)
# Appreciation
* Lucidrains
* Agorians
# Install
`pip install mgqa`
# Usage
```python
import torch
from mgqa.transformer import MGQATransformer, Decoder
x = torch.randint(0, 20000, (1, 1024))
model = MGQATransformer(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        attn_kv_heads = 2  # 2 key/value heads; with 8 query heads, 4 query heads share each kv head (see the sketch below)
    )
)
result = model(x)
print(result)
```
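For intuition, below is a minimal grouped-query attention sketch in plain PyTorch. This is an illustration of the idea only, not `mgqa`'s internals: each key/value head is shared by `heads / attn_kv_heads` query heads.
```python
# Minimal grouped-query attention sketch (illustrative only, not mgqa's internals).
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v):
    # q: (batch, q_heads, seq, dim); k, v: (batch, kv_heads, seq, dim)
    b, q_heads, n, d = q.shape
    kv_heads = k.shape[1]
    group_size = q_heads // kv_heads
    # repeat each kv head so that `group_size` query heads share it
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)
    scores = (q @ k.transpose(-2, -1)) / d ** 0.5
    return F.softmax(scores, dim=-1) @ v

# 8 query heads, 2 kv heads -> 4 query heads per kv head, as in the config above
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 2, 16, 64)
v = torch.randn(1, 2, 16, 64)
print(grouped_query_attention(q, k, v).shape)  # torch.Size([1, 8, 16, 64])
```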
# Triton
- A tentative Triton implementation. It is untested (no GPU was available to verify it); if it doesn't work and you manage to fix it, please let me know so the working attention kernel can be shared here.
```python
# !pip3 install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
# !pip3 install torch
import torch
import triton
import triton.language as tl
@triton.jit
def max_fn(x, y):
    return tl.math.max(x, y)
@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    L,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
):
    # each program handles one BLOCK_M-row block of queries for one (batch, head)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # NOTE: indexing batch and head jointly via off_hz * stride_qh assumes
    # stride_qz == H * stride_qh, i.e. contiguous (Z, H, N_CTX, D) tensors
    qvk_offset = off_hz * stride_qh
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize running max (m_i) and normalizer (l_i) for the online softmax
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout
    q = tl.load(Q_block_ptr)
    q = (q * qk_scale).to(tl.float16)
    # loop over k, v and update accumulator
    lo = 0
    hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
    for start_n in range(lo, hi, BLOCK_N):
        # -- load k, v --
        k = tl.load(K_block_ptr)
        v = tl.load(V_block_ptr)
        # -- compute qk ---
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        if IS_CAUSAL:
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        qk += tl.dot(q, k)
        # -- compute scaling constant ---
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(tl.float16), v)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    # write back l and m (L holds the row-wise base-2 log-sum-exp, reused by backward)
    acc = acc / l_i[:, None]
    l_ptrs = L + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, m_i + tl.math.log2(l_i))
    # write back O
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(O_block_ptr, acc.to(tl.float16))
@triton.jit
def _bwd_preprocess(
    Out, DO,
    Delta,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    # compute
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_m, delta)
@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale, Out, DO,
    DQ, DK, DV,
    L,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    num_block,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    CAUSAL: tl.constexpr,
):
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    qk_scale = sm_scale * 1.44269504
    # offset pointers for batch/head
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_qz + off_h * stride_qh
    V += off_z * stride_qz + off_h * stride_qh
    DO += off_z * stride_qz + off_h * stride_qh
    DQ += off_z * stride_qz + off_h * stride_qh
    DK += off_z * stride_qz + off_h * stride_qh
    DV += off_z * stride_qz + off_h * stride_qh
    for start_n in range(0, num_block):
        if CAUSAL:
            lo = start_n * BLOCK_M
        else:
            lo = 0
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_m = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # pointer to row-wise quantities in value-like data
        D_ptrs = D + off_hz * N_CTX
        l_ptrs = L + off_hz * N_CTX
        # initialize dv and dk
        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # k and v stay in SRAM throughout
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
        # loop over rows
        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load q, k, v, do on-chip
            q = tl.load(q_ptrs)
            # recompute p = softmax(qk, dim=-1).T
            if CAUSAL:
                qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
            else:
                qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, tl.trans(k))
            qk *= qk_scale
            l_i = tl.load(l_ptrs + offs_m_curr)
            p = tl.math.exp2(qk - l_i[:, None])
            # compute dv
            do = tl.load(do_ptrs)
            dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
            # compute dp = dot(v, do)
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, tl.trans(v))
            # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
            dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
            # compute dq
            dq = tl.load(dq_ptrs)
            dq += tl.dot(ds.to(Q.dtype.element_ty), k)
            tl.store(dq_ptrs, dq)
            # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
            q_ptrs += BLOCK_M * stride_qm
            do_ptrs += BLOCK_M * stride_qm
        # write-back
        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)
class _mgqa_attention(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        causal,
        sm_scale,
        num_groups
    ):
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q)
        BLOCK_M = 128
        BLOCK_N = 64
        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

        num_warps = 4 if Lk <= 64 else 8

        # divide query heads into G groups along the head dimension
        q_groups = torch.chunk(q, num_groups, dim=1)
        k_groups = torch.chunk(k, num_groups, dim=1)
        v_groups = torch.chunk(v, num_groups, dim=1)

        # CAVEAT (untested): each launch below passes a group view but reuses the
        # full-tensor grid and strides, so the kernel's (batch, head) pointer
        # arithmetic walks past the group; a correct launch likely needs
        # .contiguous() group tensors, per-group output/L buffers, and a grid
        # sized by heads-per-group rather than all heads.
        for i in range(num_groups):
            _fwd_kernel[grid](
                q_groups[i], k_groups[i], v_groups[i], sm_scale,
                L,
                o,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                o.stride(0), o.stride(1), o.stride(2), o.stride(3),
                q.shape[0], q.shape[1], q.shape[2],
                BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
                IS_CAUSAL=causal,
                num_warps=num_warps,
                num_stages=4)

        ctx.save_for_backward(q, k, v, o, L)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        ctx.causal = causal
        ctx.num_groups = num_groups  # needed by backward (autograd passes only grads)
        return o
    @staticmethod
    def backward(ctx, do):
        # autograd passes only the output gradient; recover num_groups from ctx
        num_groups = ctx.num_groups
        BLOCK = 128

        q, k, v, o, L = ctx.saved_tensors

        do = do.contiguous()

        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        delta = torch.empty_like(L)

        # divide query heads into G groups along the head dimension
        q_groups = torch.chunk(q, num_groups, dim=1)
        k_groups = torch.chunk(k, num_groups, dim=1)
        v_groups = torch.chunk(v, num_groups, dim=1)

        for i in range(num_groups):
            _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
                o, do,
                delta,
                BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
            )
            _bwd_kernel[(ctx.grid[1],)](
                q_groups[i], k_groups[i], v_groups[i], ctx.sm_scale,
                o, do,
                dq, dk, dv,
                L, delta,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                q.shape[0], q.shape[1], q.shape[2],
                ctx.grid[0],
                BLOCK_M=BLOCK, BLOCK_N=BLOCK,
                BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
                CAUSAL=ctx.causal,
                num_stages=1,
            )
        # one gradient per forward input: q, k, v, causal, sm_scale, num_groups
        return dq, dk, dv, None, None, None
attention = _mgqa_attention.apply

# Initialize random inputs on GPU in fp16 (the kernel casts/stores fp16, and
# seq_length must be a multiple of the block sizes since there are no boundary checks)
q = torch.randn(10, 8, 128, 64, device="cuda", dtype=torch.float16)  # [batch_size, num_heads, seq_length, head_dim]
k = torch.randn(10, 8, 128, 64, device="cuda", dtype=torch.float16)  # [batch_size, num_heads, seq_length, head_dim]
v = torch.randn(10, 8, 128, 64, device="cuda", dtype=torch.float16)  # [batch_size, num_heads, seq_length, head_dim]

# Set other parameters
causal = False
sm_scale = 0.1
num_groups = 4  # Number of groups to divide the query heads into

# Apply the attention
output = attention(q, k, v, causal, sm_scale, num_groups)

print(output)
```
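If you do have a GPU to test on, one way to sanity-check the kernel is to compare it against a plain PyTorch attention reference. This is a hypothetical check (untested, like the kernel itself); `q`, `k`, `v`, `causal`, `sm_scale`, and `output` refer to the example above.
```python
# Hypothetical sanity check: compare the Triton kernel's output against a
# plain PyTorch softmax-attention reference computed in fp32.
import torch

def reference_attention(q, k, v, causal, sm_scale):
    scores = (q.float() @ k.float().transpose(-2, -1)) * sm_scale
    if causal:
        n = q.shape[-2]
        # mask out positions above the diagonal for causal attention
        mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=q.device), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    return (scores.softmax(dim=-1) @ v.float()).to(q.dtype)

expected = reference_attention(q, k, v, causal, sm_scale)
torch.testing.assert_close(output, expected, atol=1e-2, rtol=0)
```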
# License
MIT
# Citations
```bibtex
@misc{2305.13245,
    Author = {Joshua Ainslie and James Lee-Thorp and Michiel de Jong and Yury Zemlyanskiy and Federico Lebrón and Sumit Sanghai},
    Title = {GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints},
    Year = {2023},
    Eprint = {arXiv:2305.13245},
}
```