rwkv-ops 0.2
Homepage: https://github.com/pass-lin/rwkv_ops
License: Apache 2.0
[English Document](ENREADME.md)

# RWKV OPS Project

> RWKV is under continuous iteration, and its core operators are updated along with it.
> This repository maintains only the operators themselves, not layers or models, and provides GPU kernels for as many frameworks as possible.

### Current Support
| Operator type | Supported frameworks |
|----------|----------|
| GPU kernels | PyTorch, JAX (TensorFlow will follow once Google ships Triton support) |
| Native operators | PyTorch, JAX, TensorFlow, NumPy |

> If the Keras ecosystem expands in the future, MLX and OpenVINO may also be supported.
> Note: this library depends on `keras`.

---

## Installation

```bash
pip install rwkv_ops
```

---

## Environment Variables

| Variable | Meaning | Values | Default | Priority |
|---|---|---|---|---|
| `KERAS_BACKEND` | Keras backend | `jax` / `torch` / `tensorflow` / `numpy` | — | low |
| `KERNEL_BACKEND` | Kernel backend | `jax` / `torch` / `tensorflow` / `numpy` | `torch` | **high** |
| `KERNEL_TYPE` | Implementation type | `triton` / `cuda` / `native` | — | — |

> If `KERNEL_BACKEND` is set, it is used directly; if it is empty, `KERAS_BACKEND` is used; if both are empty, the default is `torch`.
> `native` selects the native operators: no chunkwise kernel, slower, and with higher memory usage.
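
For example, to pin both the kernel backend and the implementation type (a minimal sketch; the variables must be set before `rwkv_ops` is first imported so they can take effect):

```python
import os

# Choose the kernel backend and implementation before importing rwkv_ops.
os.environ["KERNEL_BACKEND"] = "torch"  # takes priority over KERAS_BACKEND
os.environ["KERNEL_TYPE"] = "triton"    # triton / cuda / native

import rwkv_ops  # the backend choice is resolved from the environment here
```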

---

## rwkv7op Usage

```python
from rwkv_ops import generalized_delta_rule  # or: from rwkv_ops import rwkv7_op (fully equivalent)

def generalized_delta_rule(
    r,
    w,
    k,
    v,
    a,
    b,
    initial_state=None,
    output_final_state: bool = True,
    head_first: bool = False,
):
    """
    Chunkwise delta-rule attention interface.

    Args:
        r:  [B, T, H, K]  (the receptance, playing the role of q)
        w:  [B, T, H, K]  # decay term in log space!
        k:  [B, T, H, K]
        v:  [B, T, H, V]
        a:  [B, T, H, K]
        b:  [B, T, H, K]
        initial_state: initial state [N, H, K, V], where N is the number of sequences
        output_final_state: whether to return the final state
        head_first: whether inputs use the head-first layout; variable-length input is not supported

    Returns:
        o:           output [B, T, H, V] or [B, H, T, V]
        final_state: final state [N, H, K, V] or None
    """
```
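
A minimal call sketch on the torch backend (assumes a CUDA device; the way `a` and `b` are constructed here is purely illustrative and is not the RWKV-7 parameterization):

```python
import torch
import torch.nn.functional as F
from rwkv_ops import generalized_delta_rule

B, T, H, K, V = 2, 128, 4, 64, 64
device = "cuda"

r = torch.randn(B, T, H, K, device=device)
w = F.logsigmoid(torch.randn(B, T, H, K, device=device))  # decay in log space, i.e. <= 0
k = torch.randn(B, T, H, K, device=device)
v = torch.randn(B, T, H, V, device=device)
a = -F.normalize(torch.randn(B, T, H, K, device=device), dim=-1)  # illustrative values only
b = -a                                                            # illustrative values only

o, final_state = generalized_delta_rule(r, w, k, v, a, b, output_final_state=True)
print(o.shape)            # torch.Size([2, 128, 4, 64])
print(final_state.shape)  # torch.Size([2, 4, 64, 64])
```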

### torch-cuda Special Usage

- Under torch-cuda, `head_size` is itself a kernel parameter, defaulting to 64.
- If `head_size ≠ 64`, use:

```python
from rwkv_ops import get_generalized_delta_rule

generalized_delta_rule, RWKV7_USE_KERNEL = get_generalized_delta_rule(
    your_head_size, KERNEL_TYPE="cuda"
)
```

- `RWKV7_USE_KERNEL` is a constant that indicates whether the chunkwise kernel is in use.
- The two paths handle padding differently:

```python
if padding_mask is not None:
    if RWKV7_USE_KERNEL:
        w += (1 - padding_mask) * -1e9
    else:
        w = w * padding_mask + 1 - padding_mask
```
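
Here `padding_mask` is assumed to be 1 at valid positions and 0 at padded ones, broadcastable against `w` (the additive form matches the chunkwise kernel's log-space `w`, while the native path masks `w` directly). A hypothetical helper for building such a mask from per-sample lengths (`make_padding_mask` is not part of the library):

```python
import torch

def make_padding_mask(lengths: torch.Tensor, T: int) -> torch.Tensor:
    # 1.0 at valid positions, 0.0 at padding, shaped (B, T, 1, 1)
    # so it broadcasts against w of shape (B, T, H, K).
    valid = torch.arange(T, device=lengths.device)[None, :] < lengths[:, None]
    return valid.float()[:, :, None, None]

# e.g. two sequences of lengths 100 and 128, padded to T=128
padding_mask = make_padding_mask(torch.tensor([100, 128]), T=128)
```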

---

### rwkv7op Implementation Status

| Framework   | cuda | triton | native |
|-------------|------|--------|--------|
| PyTorch     | ✅   | ✅     | ✅     |
| JAX         | ❌   | ✅     | ✅     |
| TensorFlow  | ❌   | ❌     | ✅     |
| NumPy       | ❌   | ❌     | ✅     |

---

## rwkv6op Usage

### PyTorch Notes

- Dependencies: `keras`, `ninja`, and a complete CUDA toolkit.
- When debugging with VS Code and a virtual environment, be sure to activate the virtual environment manually in the terminal before running your code; otherwise ninja may fail to work.
- PyTorch still runs when the CUDA version inside the virtual environment differs from the global CUDA version, but keeping them identical is strongly recommended.
- PyTorch limitation: only **one** `RWKV6_OP` object may be instantiated per process; the operator is thread-safe (stateless) and can be called from multiple places.

### JAX Notes

- Dependencies: `keras`, `gcc`, `pybind11`, and a complete CUDA toolkit.
- Even if CUDA is installed for JAX inside a virtual environment, a complete system-level CUDA installation is still required; the two versions must match to keep JAX's parallel compilation fast.
- JAX compilation relies on the `/usr/local/cuda` symlink; create it manually if it does not exist:
  ```shell
  sudo ln -sf /usr/local/cuda-12.4 /usr/local/cuda
  ```
- Make sure `nvcc -V` prints normally and `which nvcc` points to the correct version (see the quick check after this list).
- JAX limitation: only **one** `RWKV6_OP` object may be instantiated per process; the operator is thread-safe (stateless) and can be called from multiple places.
- JAX ≥ 0.6.0 no longer uses the CUDA kernel and falls back to the native operator by default; version 0.4.34 is recommended.
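
A quick sanity check of the CUDA toolchain (plain shell; exact paths depend on your installation):

```shell
nvcc -V                 # should print the CUDA toolkit version
which nvcc              # should resolve under /usr/local/cuda/bin
ls -l /usr/local/cuda   # the symlink should point at the matching cuda-XX.X directory
```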

### TensorFlow Notes

- Only a native-API `RWKV6` operator is provided; it is intended for inference only and is relatively slow.

---

### Usage
Note that unlike rwkv7, which is exposed as a function, the RWKV6 op is a class and must be instantiated.
```python
from rwkv_ops import RWKV6_OP

operator = RWKV6_OP(
    head_size=64,               # head size; use 64 if unsure
    max_sequence_length=4096,   # maximum training sequence length; inference is not limited by it
    ops_loop=False              # optional: when sequence length = 1, use the high-level API instead of CUDA
)
```

#### Calling

```python
y, y_state = operator(
    r, k, v, w, u,
    with_state=False,   # whether to use a custom initial state / return the final state
    init_state=None,    # initial state [n_state, num_heads, head_size, head_size]
    state_map=None      # 1-D int32 array of length batch_size, mapping each sample to an init_state entry
)
```

| Parameter | Shape | Notes |
|---|---|---|
| r, k, v, w | (batch_size, seq_len, hidden_size) | — |
| u | (num_heads, head_size) or (hidden_size,) | — |
| init_state | (n_state, num_heads, head_size, head_size) | with n_state=1 all samples share one state; with n_state=batch_size they correspond one-to-one |
| state_map | (batch_size,) | index of the init_state entry used by each sample |

| Return value | Shape | Notes |
|---|---|---|
| y | (batch_size, seq_len, hidden_size) | output |
| y_state | (batch_size, num_heads, head_size, head_size) or None | final state |
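
Putting it together, a minimal end-to-end sketch on the torch backend (assumes a CUDA device; the second call merely illustrates the `init_state`/`state_map` convention):

```python
import torch
from rwkv_ops import RWKV6_OP

batch_size, seq_len = 2, 64
head_size, num_heads = 64, 8
hidden_size = head_size * num_heads

operator = RWKV6_OP(head_size=head_size, max_sequence_length=seq_len)

r, k, v, w = (torch.randn(batch_size, seq_len, hidden_size, device="cuda") for _ in range(4))
u = torch.randn(num_heads, head_size, device="cuda")

y, y_state = operator(r, k, v, w, u, with_state=True)
print(y.shape)        # (2, 64, 512)
print(y_state.shape)  # (2, 8, 64, 64)

# Continue from the previous state: one shared entry, mapped to every sample.
init_state = y_state[:1]                                # n_state = 1
state_map = torch.zeros(batch_size, dtype=torch.int32)  # every sample uses entry 0
y2, y2_state = operator(r, k, v, w, u, with_state=True,
                        init_state=init_state, state_map=state_map)
```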

---

### Distributed Tips

- The operator itself has no distributed support; with PyTorch you can simply use the usual multiprocess distributed training.
- With JAX, wrap the call in `shard_map` (example):

```python
import os
os.environ['KERAS_BACKEND'] = 'jax'

import jax, jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from functools import partial
from rwkv_ops import RWKV6_OP

batch_size, seq_length = 24, 512
head_size, num_heads = 64, 32
hidden_size = head_size * num_heads

mesh = Mesh(jax.devices('gpu'), axis_names=('device_axis',))
device_ns = NamedSharding(mesh, P('device_axis'))

operator = RWKV6_OP(head_size=head_size, max_sequence_length=seq_length)

@partial(shard_map,
         mesh=mesh,
         in_specs=(P('device_axis'),) * 5,
         out_specs=(P('device_axis'), P('device_axis')),
         check_rep=False)
def call_kernel(r, k, v, w, u):
    # drop the outer device dimension
    r, k, v, w, u = map(jnp.squeeze, (r, k, v, w, u))
    y, ys = operator(r, k, v, w, u, with_state=True)
    return jnp.expand_dims(y, 0), jnp.expand_dims(ys, 0)

# build the inputs and place them on the corresponding devices
keys = jax.random.split(jax.random.PRNGKey(0), 5)
inputs = [jax.random.normal(k, (mesh.size, batch_size, seq_length, hidden_size)) for k in keys]
inputs_r, inputs_k, inputs_v, inputs_w, inputs_u = map(
    lambda x: jax.device_put(x, device_ns), inputs)
inputs_u = inputs_u[:, 0, 0]  # (devices, hidden_size)

# optional: wrap with jax.jit(call_kernel, ...) for speed
outputs_y, y_state = call_kernel(inputs_r, inputs_k, inputs_v, inputs_w, inputs_u)

print(outputs_y.shape, outputs_y.sharding)
print(y_state.shape, y_state.sharding)
```

---

### rwkv6op Implementation Status

| Framework   | cuda | triton | native |
|-------------|------|--------|--------|
| PyTorch     | ✅   | ❌     | ✅     |
| JAX         | ⚠️   | ❌     | ✅     |
| TensorFlow  | ❌   | ❌     | ✅     |
| NumPy       | ❌   | ❌     | ✅     |

⚠️ The JAX CUDA implementation only works with JAX < 0.6.0; version 0.4.34 is recommended.

            
