mgqa


Namemgqa JSON
Version 0.0.5 PyPI version JSON
download
home_pagehttps://github.com/kyegomez/mgqa
Summarymgqa - Pytorch
upload_time2023-09-28 21:41:01
maintainer
docs_urlNone
authorKye Gomez
requires_python>=3.6,<4.0
licenseMIT
keywords artificial intelligence deep learning optimizers prompt engineering
VCS
bugtrack_url
requirements No requirements were recorded.
Travis-CI No Travis.
coveralls test coverage No coveralls.
            [![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf)

# MGQA
An open-source implementation of multi-grouped query attention (MGQA), based on the paper "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints".


[Paper Link](https://arxiv.org/abs/2305.13245)

# Appreciation
* Lucidrains
* Agorians

# Install
`pip install mgqa`

# Usage
```python
import torch
from mgqa.transformer import MGQATransformer, Decoder

# A batch of one sequence of 1024 random token ids drawn from [0, 20000).
x = torch.randint(0, 20000, (1, 1024))

model = MGQATransformer(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        attn_kv_heads = 2 # 8 query heads share 2 key / value heads, i.e. 4 query heads attend to 1 key / value head
    )
)

result = model(x)
print(result)
```


# Triton
- A potential Triton implementation that may or may not work; I don't have GPUs to test it. If it doesn't work and you fix it, please let me know so we can provide this useful attention kernel.

```python
# !pip3 install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
# !pip3 install torch

import torch
import triton
import triton.language as tl


@triton.jit
def max_fn(x, y):
    # Thin JIT wrapper around tl.math.max for the two operands.
    # NOTE(review): not referenced by any other code visible in this snippet.
    return tl.math.max(x, y)


@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    L,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
):
    # Forward attention pass for one (BLOCK_M x BLOCK_DMODEL) tile of queries.
    #
    # Program axis 0 selects the query-row tile (start_m); program axis 1
    # selects off_hz, which offsets the Q, K, V and Out base pointers.
    # The softmax is computed incrementally over BLOCK_N-wide key/value tiles
    # using a running row maximum (m_i), a running row sum (l_i) and an
    # un-normalised accumulator (acc).
    #
    # Writes:
    #   Out - the normalised attention output tile, stored as float16.
    #   L   - per-row m_i + log2(l_i), at L[off_hz * N_CTX + row].
    #
    # Of the strides, only stride_qh, stride_qm/qk, stride_kn/kk, stride_vk/vn
    # and stride_om/on are read; the remaining stride arguments and Z, H are
    # accepted but unused in this body.
    # NOTE(review): the loads/stores below pass no boundary_check, so this
    # presumably requires N_CTX to be a multiple of the block sizes - confirm.
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # The same offset (computed from Q's head stride) is applied to Q, K, V and Out.
    qvk_offset = off_hz * stride_qh
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    # K is addressed with its two axes swapped (BLOCK_DMODEL x N_CTX), so that
    # tl.dot(q, k) below needs no explicit transpose.
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize the running row max (m_i), running row sum (l_i) and accumulator
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout
    q = tl.load(Q_block_ptr)
    # the scale is folded into q once, so the loop needs no per-tile scaling
    q = (q * qk_scale).to(tl.float16)
    # loop over k, v and update accumulator
    lo = 0
    # in the causal case, stop after the key tile that contains the diagonal
    hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
    for start_n in range(lo, hi, BLOCK_N):
        # -- load k, v --
        k = tl.load(K_block_ptr)
        v = tl.load(V_block_ptr)
        # -- compute qk ---
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        if IS_CAUSAL:
            # positions where the key index exceeds the query index start at -inf
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        qk += tl.dot(q, k)
        # -- compute scaling constant ---
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        # alpha rescales everything accumulated under the previous maximum
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(tl.float16), v)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    # normalise the accumulator, then write back m_i + log2(l_i) per row
    acc = acc / l_i[:, None]
    l_ptrs = L + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, m_i + tl.math.log2(l_i))
    # write back O
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(O_block_ptr, acc.to(tl.float16))


@triton.jit
def _bwd_preprocess(
    Out, DO,
    Delta,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
    # Backward pre-pass: for BLOCK_M rows per program, compute the row-wise
    # sum of Out * DO over the D_HEAD columns and store it in Delta.
    #
    # Out and DO are indexed as row * D_HEAD + col, i.e. as flat row-major
    # buffers whose rows are exactly D_HEAD elements long.
    # NOTE(review): loads and the store are unmasked, so the total row count
    # presumably has to be a multiple of BLOCK_M - confirm.
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)
    # load (promoted to float32 for the reduction)
    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    # compute
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_m, delta)


@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale, Out, DO,
    DQ, DK, DV,
    L,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    num_block,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    CAUSAL: tl.constexpr,
):
    # Backward attention pass: one program per off_hz (split into off_z and
    # off_h via H). For every key/value block it recomputes the probabilities
    # from Q, K and the saved L values, accumulates dV and dK for that block,
    # and adds each row block's contribution into DQ in place.
    #
    # D holds the per-row values subtracted from dp (the wrapper passes the
    # buffer filled by _bwd_preprocess); L holds the per-row values written by
    # the forward kernel.
    #
    # Out, Z and the stride_kz/kh and stride_v* arguments are accepted but not
    # read in this body: V, DO, DQ and DV are all addressed with Q's strides,
    # and every base pointer is offset with Q's batch/head strides.
    # NOTE(review): that presumably requires all tensors to share Q's layout,
    # and DQ to be zero-initialised by the caller - confirm.
    # NOTE(review): offs_n is built from BLOCK_M and offs_m from BLOCK_N; the
    # wrapper in this snippet passes the same value for both.
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    qk_scale = sm_scale * 1.44269504
    # offset pointers for batch/head
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_qz + off_h * stride_qh
    V += off_z * stride_qz + off_h * stride_qh
    DO += off_z * stride_qz + off_h * stride_qh
    DQ += off_z * stride_qz + off_h * stride_qh
    DK += off_z * stride_qz + off_h * stride_qh
    DV += off_z * stride_qz + off_h * stride_qh
    for start_n in range(0, num_block):
        if CAUSAL:
            # in the causal case, row blocks before this key block are skipped
            lo = start_n * BLOCK_M
        else:
            lo = 0
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_m = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # pointer to row-wise quantities in value-like data
        D_ptrs = D + off_hz * N_CTX
        l_ptrs = L + off_hz * N_CTX
        # initialize dv and dk
        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # k and v stay in SRAM throughout
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
        # loop over rows
        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load q, k, v, do on-chip
            q = tl.load(q_ptrs)
            # recompute p = softmax(qk, dim=-1).T
            if CAUSAL:
                qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
            else:
                qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, tl.trans(k))
            qk *= qk_scale
            # subtracting the saved per-row L value yields the probabilities
            # without recomputing the row max / row sum
            l_i = tl.load(l_ptrs + offs_m_curr)
            p = tl.math.exp2(qk - l_i[:, None])
            # compute dv
            do = tl.load(do_ptrs)
            dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
            # compute dp = dot(do, v.T) - Di (Di is folded into the initial value)
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, tl.trans(v))
            # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
            dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
            # compute dq (read-modify-write: contributions from every key block add up)
            dq = tl.load(dq_ptrs)
            dq += tl.dot(ds.to(Q.dtype.element_ty), k)
            tl.store(dq_ptrs, dq)
            # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
            q_ptrs += BLOCK_M * stride_qm
            do_ptrs += BLOCK_M * stride_qm
        # write-back
        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)


# Allocates a 128-element tensor on the CUDA device when this snippet runs.
# NOTE(review): nothing else visible in this snippet reads `empty` - confirm
# whether it is still needed.
empty = torch.empty(128, device="cuda")


class _mgqa_attention(torch.autograd.Function):
    """Autograd wrapper that launches the Triton attention kernels per group.

    ``forward`` splits ``q``, ``k`` and ``v`` into ``num_groups`` chunks along
    dim 1 and launches ``_fwd_kernel`` once per chunk; ``backward`` does the
    same with ``_bwd_kernel``.

    NOTE(review): every launch uses a grid and strides derived from the full,
    un-chunked tensors and writes into the full ``o`` / ``dq`` / ``dk`` /
    ``dv`` buffers, while the q/k/v arguments are chunk views. Whether that
    yields correct grouped attention has not been verified on a GPU.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        causal,
        sm_scale,
        num_groups
    ):
        """Run the forward kernel and stash what ``backward`` needs.

        Args:
            ctx: autograd context object.
            q, k, v: 4-D tensors; the example in this file uses
                ``[batch_size, num_heads, seq_length, head_dim]``. Their last
                dimensions must be equal and one of 16, 32, 64 or 128.
            causal: forwarded to the kernel as ``IS_CAUSAL``.
            sm_scale: scale forwarded to the kernel.
            num_groups: number of chunks q/k/v are split into along dim 1.

        Returns:
            Tensor shaped like ``q`` that the kernel writes its output into.

        Raises:
            AssertionError: if the last dimensions differ or are unsupported.
        """
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q)
        BLOCK_M = 128
        BLOCK_N = 64
        # axis 0: row tiles of dim 2; axis 1: one program per (dim 0, dim 1) pair
        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
        # per-row values written by the forward kernel and read again in backward
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

        num_warps = 4 if Lk <= 64 else 8

        # divide query heads into G groups
        q_groups = torch.chunk(q, num_groups, dim=1)
        k_groups = torch.chunk(k, num_groups, dim=1)
        v_groups = torch.chunk(v, num_groups, dim=1)

        for i in range(num_groups):
            _fwd_kernel[grid](
                q_groups[i], k_groups[i], v_groups[i], sm_scale,
                L,
                o,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                o.stride(0), o.stride(1), o.stride(2), o.stride(3),
                q.shape[0], q.shape[1], q.shape[2],
                BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
                IS_CAUSAL=causal,
                num_warps=num_warps,
                num_stages=4)

        ctx.save_for_backward(q, k, v, o, L)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        ctx.causal = causal
        # autograd passes only output gradients to backward, so the group
        # count has to travel on the context
        ctx.num_groups = num_groups
        return o

    @staticmethod
    def backward(ctx, do):
        """Compute gradients for ``q``, ``k`` and ``v``.

        Args:
            ctx: autograd context populated by ``forward``.
            do: gradient of the loss with respect to the forward output.

        Returns:
            A 6-tuple with one entry per ``forward`` input after ``ctx``:
            ``(dq, dk, dv, None, None, None)``. ``causal``, ``sm_scale`` and
            ``num_groups`` receive no gradient.
        """
        BLOCK = 128
        num_groups = ctx.num_groups

        q, k, v, o, L = ctx.saved_tensors

        do = do.contiguous()

        # dq starts at zero because the kernel accumulates into it in place
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        delta = torch.empty_like(L)

        # divide query heads into G groups
        q_groups = torch.chunk(q, num_groups, dim=1)
        k_groups = torch.chunk(k, num_groups, dim=1)
        v_groups = torch.chunk(v, num_groups, dim=1)

        # delta depends only on o and do, so it is computed once rather than
        # once per group
        if num_groups > 0:
            _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
                o, do,
                delta,
                BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
            )

        for i in range(num_groups):
            _bwd_kernel[(ctx.grid[1],)](
                q_groups[i], k_groups[i], v_groups[i], ctx.sm_scale,
                o, do,
                dq, dk, dv,
                L, delta,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                q.shape[0], q.shape[1], q.shape[2],
                ctx.grid[0],
                BLOCK_M=BLOCK, BLOCK_N=BLOCK,
                BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
                CAUSAL=ctx.causal,
                num_stages=1,
            )
        # one gradient slot per forward input: q, k, v, causal, sm_scale, num_groups
        return dq, dk, dv, None, None, None


attention = _mgqa_attention.apply

# Initialize random inputs.
# The kernels above are Triton GPU kernels that convert their operands to
# float16 and load whole 128-row tiles without bounds checks, so the inputs
# are created on the CUDA device, in float16, with a sequence length that is
# a multiple of 128.
q = torch.randn(10, 8, 128, 64, device="cuda", dtype=torch.float16)  # [batch_size, num_heads, seq_length, head_dim]
k = torch.randn(10, 8, 128, 64, device="cuda", dtype=torch.float16)  # [batch_size, num_heads, seq_length, head_dim]
v = torch.randn(10, 8, 128, 64, device="cuda", dtype=torch.float16)  # [batch_size, num_heads, seq_length, head_dim]

# Set other parameters
causal = False
sm_scale = 0.1
num_groups = 4  # Number of groups to divide the query heads into

# Apply the attention
output = attention(q, k, v, causal, sm_scale, num_groups)

print(output)
```

# License
MIT

# Citations
```bibtex
@misc{2305.13245,
Author = {Joshua Ainslie and James Lee-Thorp and Michiel de Jong and Yury Zemlyanskiy and Federico Lebrón and Sumit Sanghai},
Title = {GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints},
Year = {2023},
Eprint = {arXiv:2305.13245},
}
```
            

Raw data

            {
    "_id": null,
    "home_page": "https://github.com/kyegomez/mgqa",
    "name": "mgqa",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.6,<4.0",
    "maintainer_email": "",
    "keywords": "artificial intelligence,deep learning,optimizers,Prompt Engineering",
    "author": "Kye Gomez",
    "author_email": "kye@apac.ai",
    "download_url": "https://files.pythonhosted.org/packages/ea/0f/29785c77e14bd1b9cc2bf03b5a1ed85cd6066379fee9566bf5c982fd0012/mgqa-0.0.5.tar.gz",
    "platform": null,
    "description": "[![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf)\n\n# MGQA\nThe open source implementation of the multi grouped query attention by the paper \"GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints\"\n\n\n[Paper Link](https://arxiv.org/abs/2305.13245)\n\n# Appreciation\n* Lucidrains\n* Agorians\n\n# Install\n`pip install mgqa`\n\n# Usage\n```python\nimport torch\nfrom mgqa.transformer import MGQATransformer, Decoder\n\nx = torch.randint(0, 20000, (1, 1024))\n\nmodel = MGQATransformer(e\n    num_tokens = 20000,\n    max_seq_len = 1024,\n    attn_layers = Decoder(\n        dim = 512,\n        depth = 12,\n        heads = 8,\n        attn_kv_heads = 2 # say you want 4 query heads to attend to 1 key / value head\n    )\n)\n\nresult = model(x)\nprint(result)\n```\n\n\n# Triton\n- A potential triton implementation that may or may not work, I don't have gpus to test this out. If it doesn't work and you fix please let me know so we can provide this useful attn\n\n```python\n# !pip3 install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly\n# !pip3 install torch\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef max_fn(x, y):\n    return tl.math.max(x, y)\n\n\n@triton.jit\ndef _fwd_kernel(\n    Q, K, V, sm_scale,\n    L,\n    Out,\n    stride_qz, stride_qh, stride_qm, stride_qk,\n    stride_kz, stride_kh, stride_kn, stride_kk,\n    stride_vz, stride_vh, stride_vk, stride_vn,\n    stride_oz, stride_oh, stride_om, stride_on,\n    Z, H, N_CTX,\n    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    IS_CAUSAL: tl.constexpr,\n):\n    start_m = tl.program_id(0)\n    off_hz = tl.program_id(1)\n    qvk_offset = off_hz * stride_qh\n    Q_block_ptr = tl.make_block_ptr(\n        base=Q + qvk_offset,\n        shape=(N_CTX, BLOCK_DMODEL),\n        strides=(stride_qm, stride_qk),\n     
   offsets=(start_m * BLOCK_M, 0),\n        block_shape=(BLOCK_M, BLOCK_DMODEL),\n        order=(1, 0)\n    )\n    K_block_ptr = tl.make_block_ptr(\n        base=K + qvk_offset,\n        shape=(BLOCK_DMODEL, N_CTX),\n        strides=(stride_kk, stride_kn),\n        offsets=(0, 0),\n        block_shape=(BLOCK_DMODEL, BLOCK_N),\n        order=(0, 1)\n    )\n    V_block_ptr = tl.make_block_ptr(\n        base=V + qvk_offset,\n        shape=(N_CTX, BLOCK_DMODEL),\n        strides=(stride_vk, stride_vn),\n        offsets=(0, 0),\n        block_shape=(BLOCK_N, BLOCK_DMODEL),\n        order=(1, 0)\n    )\n    # initialize offsets\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    offs_n = tl.arange(0, BLOCK_N)\n    # initialize pointer to m and l\n    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n    # scale sm_scale by log_2(e) and use\n    # 2^x instead of exp in the loop because CSE and LICM\n    # don't work as expected with `exp` in the loop\n    qk_scale = sm_scale * 1.44269504\n    # load q: it will stay in SRAM throughout\n    q = tl.load(Q_block_ptr)\n    q = (q * qk_scale).to(tl.float16)\n    # loop over k, v and update accumulator\n    lo = 0\n    hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX\n    for start_n in range(lo, hi, BLOCK_N):\n        # -- load k, v --\n        k = tl.load(K_block_ptr)\n        v = tl.load(V_block_ptr)\n        # -- compute qk ---\n        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n        if IS_CAUSAL:\n            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n        qk += tl.dot(q, k)\n        # -- compute scaling constant ---\n        m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n        alpha = tl.math.exp2(m_i - m_i_new)\n        p = tl.math.exp2(qk - m_i_new[:, None])\n        # -- scale and update acc --\n        acc_scale = l_i * 0 + alpha  
# workaround some compiler bug\n        acc *= acc_scale[:, None]\n        acc += tl.dot(p.to(tl.float16), v)\n        # -- update m_i and l_i --\n        l_i = l_i * alpha + tl.sum(p, 1)\n        m_i = m_i_new\n        # update pointers\n        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n    # write back l and m\n    acc = acc / l_i[:, None]\n    l_ptrs = L + off_hz * N_CTX + offs_m\n    tl.store(l_ptrs, m_i + tl.math.log2(l_i))\n    # write back O\n    O_block_ptr = tl.make_block_ptr(\n        base=Out + qvk_offset,\n        shape=(N_CTX, BLOCK_DMODEL),\n        strides=(stride_om, stride_on),\n        offsets=(start_m * BLOCK_M, 0),\n        block_shape=(BLOCK_M, BLOCK_DMODEL),\n        order=(1, 0)\n    )\n    tl.store(O_block_ptr, acc.to(tl.float16))\n\n\n@triton.jit\ndef _bwd_preprocess(\n    Out, DO,\n    Delta,\n    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,\n):\n    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n    off_n = tl.arange(0, D_HEAD)\n    # load\n    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n    # compute\n    delta = tl.sum(o * do, axis=1)\n    # write-back\n    tl.store(Delta + off_m, delta)\n\n\n@triton.jit\ndef _bwd_kernel(\n    Q, K, V, sm_scale, Out, DO,\n    DQ, DK, DV,\n    L,\n    D,\n    stride_qz, stride_qh, stride_qm, stride_qk,\n    stride_kz, stride_kh, stride_kn, stride_kk,\n    stride_vz, stride_vh, stride_vk, stride_vn,\n    Z, H, N_CTX,\n    num_block,\n    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    CAUSAL: tl.constexpr,\n):\n    off_hz = tl.program_id(0)\n    off_z = off_hz // H\n    off_h = off_hz % H\n    qk_scale = sm_scale * 1.44269504\n    # offset pointers for batch/head\n    Q += off_z * stride_qz + off_h * stride_qh\n    K += off_z * stride_qz + off_h * stride_qh\n    V += off_z * 
stride_qz + off_h * stride_qh\n    DO += off_z * stride_qz + off_h * stride_qh\n    DQ += off_z * stride_qz + off_h * stride_qh\n    DK += off_z * stride_qz + off_h * stride_qh\n    DV += off_z * stride_qz + off_h * stride_qh\n    for start_n in range(0, num_block):\n        if CAUSAL:\n            lo = start_n * BLOCK_M\n        else:\n            lo = 0\n        # initialize row/col offsets\n        offs_qm = lo + tl.arange(0, BLOCK_M)\n        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)\n        offs_m = tl.arange(0, BLOCK_N)\n        offs_k = tl.arange(0, BLOCK_DMODEL)\n        # initialize pointers to value-like data\n        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n        v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n        # pointer to row-wise quantities in value-like data\n        D_ptrs = D + off_hz * N_CTX\n        l_ptrs = L + off_hz * N_CTX\n        # initialize dv amd dk\n        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n        # k and v stay in SRAM throughout\n        k = tl.load(k_ptrs)\n        v = tl.load(v_ptrs)\n        # loop over rows\n        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):\n            offs_m_curr = start_m + offs_m\n            # load q, k, v, do on-chip\n            q = tl.load(q_ptrs)\n            # recompute p = softmax(qk, dim=-1).T\n            if CAUSAL:\n                qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float(\"-inf\"))\n            else:\n                qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n            qk += tl.dot(q, tl.trans(k))\n            qk *= qk_scale\n  
          l_i = tl.load(l_ptrs + offs_m_curr)\n            p = tl.math.exp2(qk - l_i[:, None])\n            # compute dv\n            do = tl.load(do_ptrs)\n            dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)\n            # compute dp = dot(v, do)\n            Di = tl.load(D_ptrs + offs_m_curr)\n            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n            dp += tl.dot(do, tl.trans(v))\n            # compute ds = p * (dp - delta[:, None])\n            ds = p * dp * sm_scale\n            # compute dk = dot(ds.T, q)\n            dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)\n            # compute dq\n            dq = tl.load(dq_ptrs)\n            dq += tl.dot(ds.to(Q.dtype.element_ty), k)\n            tl.store(dq_ptrs, dq)\n            # increment pointers\n            dq_ptrs += BLOCK_M * stride_qm\n            q_ptrs += BLOCK_M * stride_qm\n            do_ptrs += BLOCK_M * stride_qm\n        # write-back\n        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n        tl.store(dv_ptrs, dv)\n        tl.store(dk_ptrs, dk)\n\n\nempty = torch.empty(128, device=\"cuda\")\n\n\nclass _mgqa_attention(torch.autograd.Function):\n\n    @staticmethod\n    def forward(\n        ctx, \n        q, \n        k, \n        v, \n        causal, \n        sm_scale, \n        num_groups\n    ):\n        # shape constraints\n        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n        assert Lq == Lk and Lk == Lv\n        assert Lk in {16, 32, 64, 128}\n        o = torch.empty_like(q)\n        BLOCK_M = 128\n        BLOCK_N = 64\n        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)\n        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n\n        num_warps = 4 if Lk <= 64 else 8\n        \n        #divide query heads into G groups\n        q_groups = 
torch.chunk(q, num_groups, dim=1)\n        k_groups = torch.chunk(k, num_groups, dim=1)\n        v_groups = torch.chunk(v, num_groups, dim=1)\n\n        for i in range(num_groups):    \n            _fwd_kernel[grid](\n                q_groups[i], k_groups[i], v_groups[i], sm_scale,\n                L,\n                o,\n                q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n                k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n                v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n                o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n                q.shape[0], q.shape[1], q.shape[2],\n                BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,\n                IS_CAUSAL=causal,\n                num_warps=num_warps,\n                num_stages=4)\n\n        ctx.save_for_backward(q, k, v, o, L)\n        ctx.grid = grid\n        ctx.sm_scale = sm_scale\n        ctx.BLOCK_DMODEL = Lk\n        ctx.causal = causal\n        return o\n\n    @staticmethod\n    def backward(ctx, do, num_groups):\n        BLOCK = 128\n\n        q, k, v, o, L = ctx.saved_tensors\n\n        do = do.contiguous()\n\n        dq = torch.zeros_like(q, dtype=torch.float32)\n        dk = torch.empty_like(k)\n        dv = torch.empty_like(v)\n        delta = torch.empty_like(L)\n\n        #divide query heads into G groups\n        q_groups = torch.chunk(q, num_groups, dim=1)\n        k_groups = torch.chunk(k, num_groups, dim=1)\n        v_groups = torch.chunk(v, num_groups, dim=1)\n        \n        for i in range(num_groups):\n            _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](\n                o, do,\n                delta,\n                BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,\n            )\n            _bwd_kernel[(ctx.grid[1],)](\n                q_groups[i], k_groups[i], v_groups[i], ctx.sm_scale,\n                o, do,\n                dq, dk, dv,\n                L, delta,\n                q.stride(0), q.stride(1), 
q.stride(2), q.stride(3),\n                k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n                v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n                q.shape[0], q.shape[1], q.shape[2],\n                ctx.grid[0],\n                BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n                BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,\n                CAUSAL=ctx.causal,\n                num_stages=1,\n            )\n        return dq, dk, dv, None, None\n\n\nattention = _mgqa_attention.apply\n\n# Initialize random inputs\nq = torch.randn(10, 8, 16, 64)  # [batch_size, num_heads, seq_length, head_dim]\nk = torch.randn(10, 8, 16, 64)  # [batch_size, num_heads, seq_length, head_dim]\nv = torch.randn(10, 8, 16, 64)  # [batch_size, num_heads, seq_length, head_dim]\n\n# Set other parameters\ncausal = False\nsm_scale = 0.1\nnum_groups = 4  # Number of groups to divide the query heads into\n\n# Apply the attention\noutput = attention(q, k, v, causal, sm_scale, num_groups)\n\nprint(output)\n```\n\n# License\nMIT\n\n# Citations\n```biblitex\n@misc{2305.13245,\nAuthor = {Joshua Ainslie and James Lee-Thorp and Michiel de Jong and Yury Zemlyanskiy and Federico Lebr\u00f3n and Sumit Sanghai},\nTitle = {GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints},\nYear = {2023},\nEprint = {arXiv:2305.13245},\n}\n```",
    "bugtrack_url": null,
    "license": "MIT",
    "summary": "mgqa - Pytorch",
    "version": "0.0.5",
    "project_urls": {
        "Homepage": "https://github.com/kyegomez/mgqa",
        "Repository": "https://github.com/kyegomez/mgqa"
    },
    "split_keywords": [
        "artificial intelligence",
        "deep learning",
        "optimizers",
        "prompt engineering"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "7e0ad5c40ce8c9e6ecba5dc16e24d7ba8f48e63b64f3b42a16f6ef4f6bd8fc55",
                "md5": "fcdbe509e14160c6a90cb211bc6caf71",
                "sha256": "6ed6930710caa7a85408508913cb0529bfd19dc663c21ea20a3c6d8b7fa77da6"
            },
            "downloads": -1,
            "filename": "mgqa-0.0.5-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "fcdbe509e14160c6a90cb211bc6caf71",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.6,<4.0",
            "size": 24847,
            "upload_time": "2023-09-28T21:40:57",
            "upload_time_iso_8601": "2023-09-28T21:40:57.909882Z",
            "url": "https://files.pythonhosted.org/packages/7e/0a/d5c40ce8c9e6ecba5dc16e24d7ba8f48e63b64f3b42a16f6ef4f6bd8fc55/mgqa-0.0.5-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "ea0f29785c77e14bd1b9cc2bf03b5a1ed85cd6066379fee9566bf5c982fd0012",
                "md5": "c65e32fd82e7346a3228be11b0f03964",
                "sha256": "9acbb77f472724228dfe96288af7101e25c5d90fecc6ed9d417328f766dbc8a7"
            },
            "downloads": -1,
            "filename": "mgqa-0.0.5.tar.gz",
            "has_sig": false,
            "md5_digest": "c65e32fd82e7346a3228be11b0f03964",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.6,<4.0",
            "size": 27130,
            "upload_time": "2023-09-28T21:41:01",
            "upload_time_iso_8601": "2023-09-28T21:41:01.117988Z",
            "url": "https://files.pythonhosted.org/packages/ea/0f/29785c77e14bd1b9cc2bf03b5a1ed85cd6066379fee9566bf5c982fd0012/mgqa-0.0.5.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2023-09-28 21:41:01",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "kyegomez",
    "github_project": "mgqa",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "requirements": [],
    "lcname": "mgqa"
}
        
Elapsed time: 0.37857s