[English Document](ENREADME.md)
# RWKV OPS Project
> Since RWKV is under continuous iteration, its core operators will keep evolving.
> This repository maintains only the operators themselves, not layers or models, and aims to provide GPU operators for every framework.
### Currently Supported

| Operator type | Framework support |
|----------|----------|
| GPU operators | PyTorch, JAX (TensorFlow will follow once Google ships Triton support) |
| Native operators | PyTorch, JAX, TensorFlow, NumPy |

> If the Keras ecosystem expands in the future, MLX and OpenVINO may be supported.
> Note: this library depends on `keras`.
---
## Installation
```bash
pip install rwkv_ops
```
---
## Environment Variables

| Variable | Meaning | Values | Default | Priority |
|---|---|---|---|---|
| `KERAS_BACKEND` | Keras backend | `jax` / `torch` / `tensorflow` / `numpy` | — | low |
| `KERNEL_BACKEND` | Operator backend | `jax` / `torch` / `tensorflow` / `numpy` | `torch` | **high** |
| `KERNEL_TYPE` | Implementation type | `triton` / `cuda` / `native` | — | — |

> If `KERNEL_BACKEND` is set, it is used directly; if it is empty, `KERAS_BACKEND` is used; if both are empty, the default is `torch`.
> `native` selects the native operator: it has no chunkwise kernel, so it is slower and uses more GPU memory.
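
As a minimal sketch of pinning the backend explicitly (the variable names are the ones documented above; setting them before the first import follows the usual Keras convention):

```python
import os

# Choose the operator backend and implementation before importing rwkv_ops
os.environ["KERAS_BACKEND"] = "torch"
os.environ["KERNEL_BACKEND"] = "torch"  # takes priority over KERAS_BACKEND
os.environ["KERNEL_TYPE"] = "cuda"      # or "triton" / "native"

import rwkv_ops
```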
---
## rwkv7op Usage
```python
from rwkv_ops import generalized_delta_rule  # or: from rwkv_ops import rwkv7_op (fully equivalent)

def generalized_delta_rule(
    r,
    w,
    k,
    v,
    a,
    b,
    initial_state=None,
    output_final_state: bool = True,
    head_first: bool = False,
):
    """
    Chunkwise delta-rule attention interface.

    Args:
        r: [B, T, H, K]
        w: [B, T, H, K]  # decay term, in log space!
        k: [B, T, H, K]
        v: [B, T, H, V]
        a: [B, T, H, K]
        b: [B, T, H, K]
        initial_state: initial state, [N, H, K, V]; N is the number of sequences
        output_final_state: whether to return the final state
        head_first: whether inputs use the head-first layout; incompatible with variable-length input

    Returns:
        o: output, [B, T, H, V] or [B, H, T, V]
        final_state: final state, [N, H, K, V] or None
    """
```
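
A minimal call sketch, assuming the PyTorch backend; the tensors below are random placeholders chosen only to illustrate shapes (real RWKV-7 models constrain `a`, `b`, and the log-space decay `w`):

```python
import torch
from rwkv_ops import generalized_delta_rule

B, T, H, K, V = 2, 16, 4, 64, 64  # batch, time, heads, key dim, value dim

r, k, a, b = (torch.randn(B, T, H, K, device="cuda") for _ in range(4))
v = torch.randn(B, T, H, V, device="cuda")
w = -torch.rand(B, T, H, K, device="cuda")  # log-space decay, so values are <= 0

o, final_state = generalized_delta_rule(
    r, w, k, v, a, b,
    initial_state=None,
    output_final_state=True,
)
print(o.shape)            # (B, T, H, V)
print(final_state.shape)  # (N, H, K, V), with N = B here
```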
### torch-cuda Specifics

- Under torch-cuda, `head_size` is itself a kernel parameter and defaults to 64.
- If `head_size ≠ 64`, use:
```python
from rwkv_ops import get_generalized_delta_rule
generalized_delta_rule, RWKV7_USE_KERNEL = get_generalized_delta_rule(
your_head_size, KERNEL_TYPE="cuda"
)
```
- `RWKV7_USE_KERNEL` is a constant that marks whether the chunkwise kernel is in use.
- The two paths handle padding differently:

```python
if padding_mask is not None:
    if RWKV7_USE_KERNEL:
        # chunkwise kernel: w lives in log space, so a large negative
        # value drives the decay at padded positions to zero
        w += (1 - padding_mask) * -1e9
    else:
        # native operator: set the decay at padded positions to 1 (identity)
        w = w * padding_mask + 1 - padding_mask
```
---
### rwkv7op Implementation Status
| Framework | cuda | triton | native |
|-------------|------|--------|--------|
| PyTorch | ✅ | ✅ | ✅ |
| JAX | ❌ | ✅ | ✅ |
| TensorFlow | ❌ | ❌ | ✅ |
| NumPy | ❌ | ❌ | ✅ |
---
## rwkv6op Usage

### PyTorch Notes

- Dependencies: `keras`, `ninja`, and a complete CUDA toolkit.
- When debugging in VS Code with a virtual environment, be sure to activate the environment manually in the terminal before running code; otherwise ninja may fail to work.
- Although PyTorch still runs when the virtual environment's CUDA version differs from the global one, keeping the two identical is strongly recommended.
- PyTorch limitation: only **one** `RWKV6_OP` object can be instantiated per program; the operator is thread-safe (stateless) and can be called from multiple places.

### JAX Notes

- Dependencies: `keras`, `gcc`, `pybind11`, and a complete CUDA toolkit.
- Even if CUDA is installed for JAX inside a virtual environment, a complete system-level CUDA install is still required, and the two versions must match to keep JAX's parallel compilation fast.
- JAX compilation relies on the `/usr/local/cuda` symlink; create it manually if it is missing:
  ```shell
  sudo ln -sf /usr/local/cuda-12.4 /usr/local/cuda
  ```
- Make sure `nvcc -V` prints correctly and that `which nvcc` points at the right version; a quick check is sketched after this list.
- JAX limitation: only **one** `RWKV6_OP` object can be instantiated per program; the operator is thread-safe (stateless) and can be called from multiple places.
- JAX ≥ 0.6.0 no longer uses the CUDA operator and falls back to the native one; 0.4.34 is recommended.
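
A quick toolchain sanity check (plain shell; nothing here is rwkv_ops-specific):

```shell
nvcc -V      # should print the expected CUDA release, e.g. "release 12.4"
which nvcc   # should resolve under /usr/local/cuda/bin
```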
### TensorFlow Notes

- Only a native-API `RWKV6` operator is provided; it is inference-only and relatively slow.
---
### Usage

Unlike rwkv7, which is exposed as a function, the RWKV6 op is a class and must be instantiated.
```python
from rwkv_ops import RWKV6_OP

operator = RWKV6_OP(
    head_size=64,              # head size; use 64 if unsure
    max_sequence_length=4096,  # maximum training sequence length; inference is unrestricted
    ops_loop=False             # optional: when seq_len == 1, use the high-level API instead of CUDA
)
```
#### Calling
```python
y, y_state = operator(
    r, k, v, w, u,
    with_state=False,  # whether to use a custom initial state / return the final state
    init_state=None,   # initial state, [n_state, num_heads, head_size, head_size]
    state_map=None     # 1-D int32 array of length batch_size; maps each sample to an init_state entry
)
```
| Parameter | Shape | Notes |
|---|---|---|
| r, k, v, w | (batch_size, seq_len, hidden_size) | — |
| u | (num_heads, head_size) or (hidden_size,) | — |
| init_state | (n_state, num_heads, head_size, head_size) | n_state=1: shared by all samples; n_state=batch_size: one state per sample |
| state_map | (batch_size,) | index of the init_state entry used by each sample |

| Return value | Shape | Notes |
|---|---|---|
| y | (batch_size, seq_len, hidden_size) | output |
| y_state | (batch_size, num_heads, head_size, head_size) or None | final state |
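
As a sketch of the `state_map` mechanics, assuming `operator` and the `r, k, v, w, u` tensors from the call above already exist (the zero arrays are placeholders):

```python
import numpy as np

batch_size, num_heads, head_size = 4, 32, 64

# One shared initial state; state_map points every sample at entry 0
init_state = np.zeros((1, num_heads, head_size, head_size), dtype=np.float32)
state_map = np.zeros((batch_size,), dtype=np.int32)

y, y_state = operator(
    r, k, v, w, u,
    with_state=True,
    init_state=init_state,
    state_map=state_map,
)
# y_state: (batch_size, num_heads, head_size, head_size)
```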
---
### Distributed Tips

- The operator itself has no distributed support; with PyTorch, multi-threaded distributed training can be used directly.
- In JAX, wrap it with `shard_map` (example):
```python
import os
os.environ['KERAS_BACKEND'] = 'jax'

import jax, jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from functools import partial
from rwkv_ops import RWKV6_OP

batch_size, seq_length = 24, 512
head_size, num_heads = 64, 32
hidden_size = head_size * num_heads

mesh = Mesh(jax.devices('gpu'), axis_names=('device_axis',))
device_ns = NamedSharding(mesh, P('device_axis'))

operator = RWKV6_OP(head_size=head_size, max_sequence_length=seq_length)

@partial(shard_map,
         mesh=mesh,
         in_specs=(P('device_axis'),) * 5,
         out_specs=(P('device_axis'), P('device_axis')),
         check_rep=False)
def call_kernel(r, k, v, w, u):
    # drop the leading device axis
    r, k, v, w, u = map(jnp.squeeze, (r, k, v, w, u))
    y, ys = operator(r, k, v, w, u, with_state=True)
    return jnp.expand_dims(y, 0), jnp.expand_dims(ys, 0)

# Build the inputs and place them on the corresponding devices
keys = jax.random.split(jax.random.PRNGKey(0), 5)
inputs = [jax.random.normal(k, (mesh.size, batch_size, seq_length, hidden_size)) for k in keys]
inputs_r, inputs_k, inputs_v, inputs_w, inputs_u = map(
    lambda x: jax.device_put(x, device_ns), inputs)
inputs_u = inputs_u[:, 0, 0]  # keep one u vector per device: (devices, hidden_size)

# Optional: jax.jit(call_kernel, ...) for speed
outputs_y, y_state = call_kernel(inputs_r, inputs_k, inputs_v, inputs_w, inputs_u)

print(outputs_y.shape, outputs_y.sharding)
print(y_state.shape, y_state.sharding)
```
---
### rwkv6op Implementation Status
| Framework | cuda | triton | native |
|-------------|------|--------|--------|
| PyTorch | ✅ | ❌ | ✅ |
| JAX | ⚠️ | ❌ | ✅ |
| TensorFlow | ❌ | ❌ | ✅ |
| NumPy | ❌ | ❌ | ✅ |
⚠️ The JAX CUDA implementation only works with JAX < 0.6.0; 0.4.34 is recommended.