# Triton-Optimized Cross-Entropy Kernel
A high-performance, memory-efficient cross-entropy loss implementation using [Triton](https://github.com/openai/triton) for CUDA GPUs. It is designed to be significantly faster than PyTorch's native cross-entropy, especially at the large vocabulary sizes typical of large language models.
> **Attribution:**
> This implementation is adapted from [Unsloth's cross-entropy kernel](https://github.com/unslothai/unsloth/blob/1898b6d049d606ec88f3f9307172373776eec0f6/unsloth/kernels/cross_entropy_loss.py).
---
## Features
- **Memory Efficient:** Fused kernel reduces memory footprint.
- **High Performance:** Optimized for large vocabulary sizes with Triton JIT.
- **Causal LM Compatible:** Handles shifted logits/labels for autoregressive language modeling.
- **Ignore Index Support:** Configurable ignore index for masking tokens (default: `-100`).
- **CUDA Accelerated:** Fully utilizes CUDA GPUs for maximum throughput.
- **Autograd Compatible:** Exposes a PyTorch-compatible `autograd.Function` and `nn.Module`.
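
To make the ignore-index behaviour concrete, here is a minimal pure-Python reference sketch of a mean cross-entropy that skips masked positions. It is not the Triton kernel; the helper names `log_softmax` and `masked_mean_cross_entropy` are hypothetical, and returning `0.0` when every position is masked is a choice made for this sketch only.

```python
import math

IGNORE_INDEX = -100  # default ignore index named in the feature list


def log_softmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def masked_mean_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean negative log-likelihood over positions whose label is not ignored."""
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # masked position: adds neither loss nor count
        total -= log_softmax(row)[label]
        count += 1
    # Sketch-only choice: report 0.0 when nothing is left to average over.
    return total / count if count else 0.0


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3], [1.0, 1.0, 1.0]]
labels = [0, IGNORE_INDEX, 2]  # the middle position is masked
loss = masked_mean_cross_entropy(logits, labels)
```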
---
## Requirements
- PyTorch (CUDA-enabled)
- Triton
- CUDA-compatible GPU
---
## Installation
Install from PyPI:
```bash
pip install crossentropy-triton
```
Or install it together with PyTorch and Triton (append version specifiers such as `torch==<version>` to pin specific releases):
```bash
pip install crossentropy-triton torch triton
```
---
## Usage
### Basic Usage (Autograd Function)
```python
import torch
from crossentropy_triton import CrossEntropyFunction
device = torch.device('cuda')
# Create sample data [batch, seq, vocab_size]
logits = torch.randn(2, 10, 32000, device=device, requires_grad=True)
labels = torch.randint(0, 32000, (2, 10), device=device)
# Forward pass with ignore_index=-100 (default for masked tokens)
loss = CrossEntropyFunction.apply(logits, labels, -100)
print(f"Loss: {loss.item():.4f}")
# Backward pass
loss.backward()
print(f"Gradients computed - shape: {logits.grad.shape}")
```
### Using the Causal LM Loss Module
```python
import torch
from crossentropy_triton import TritonCausalLMLoss
device = torch.device('cuda')
vocab_size = 32000
# Initialize the loss function
loss_fn = TritonCausalLMLoss(vocab_size)
# Create sample data
logits = torch.randn(2, 10, vocab_size, device=device, requires_grad=True)
labels = torch.randint(0, vocab_size, (2, 10), device=device)
# Forward and backward pass
loss = loss_fn(logits, labels)
print(f"Causal LM loss: {loss.item():.4f}")
loss.backward()
print("Backward pass completed")
```
---
## Performance Characteristics
- **Optimized Block Size:** Chooses optimal kernel block sizes up to 32,768.
- **Memory Fusion:** Fuses softmax and gradient computation in a single kernel.
- **Efficient Masking:** Ignore index is handled directly in the kernel.
- **Gradient Scaling:** Gradients are normalized by the number of non-ignored tokens.
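
The block-size point can be illustrated with a small sketch. It assumes a common heuristic for fused Triton kernels: round the vocabulary size up to the next power of two and cap the result at 32,768. Whether this package uses exactly this rule is an assumption, and `choose_block_size` and `num_chunks` are hypothetical helper names.

```python
MAX_FUSED_SIZE = 32768  # cap quoted above (32,768 = 2**15)


def next_power_of_2(n):
    """Smallest power of two >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def choose_block_size(vocab_size, max_size=MAX_FUSED_SIZE):
    """Assumed heuristic: round the row length up to a power of two, then cap."""
    return min(max_size, next_power_of_2(vocab_size))


def num_chunks(vocab_size, block_size):
    """Number of block-sized chunks needed to cover one row of logits."""
    return (vocab_size + block_size - 1) // block_size


block = choose_block_size(32000)   # one block covers the whole row
chunks = num_chunks(128256, choose_block_size(128256))  # larger vocab needs several
```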
---
## Technical Details
### Kernel Implementation
- **`cross_entropy_kernel`:** Computes the loss in the forward pass and writes the gradients in place into the logits tensor.
- **`element_mul_kernel`:** Scales those in-place gradients by the upstream gradient during the backward pass.
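
The two-step pattern described above can be sketched in plain Python for a single row. This is an illustration of the idea, not the Triton code: `fused_cross_entropy_row` and `scale_in_place` are hypothetical names, and the sketch omits ignore-index handling and normalization.

```python
import math


def fused_cross_entropy_row(logits_row, label):
    """Return the loss for one row and overwrite the row with d(loss)/d(logits)."""
    m = max(logits_row)
    denom = sum(math.exp(x - m) for x in logits_row)
    loss = m + math.log(denom) - logits_row[label]
    for i, x in enumerate(logits_row):
        # Gradient of cross-entropy w.r.t. logit i is softmax_i - 1[i == label].
        logits_row[i] = math.exp(x - m) / denom - (1.0 if i == label else 0.0)
    return loss


def scale_in_place(grad_row, grad_output):
    """Multiply the stored gradient by the upstream gradient (chain rule)."""
    for i in range(len(grad_row)):
        grad_row[i] *= grad_output


row = [2.0, 0.5, -1.0]
loss = fused_cross_entropy_row(row, label=0)  # row now holds gradients, not logits
scale_in_place(row, grad_output=0.5)
```

Reusing the logits buffer for the gradients avoids allocating a second vocabulary-sized tensor, at the cost of the original logits no longer being available after the forward pass.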
### Memory and Numerical Stability
- Supports both contiguous and non-contiguous tensors.
- In-place gradient computation for minimal overhead.
- Log-sum-exp trick for stable softmax.
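
A short sketch of why the log-sum-exp trick matters: the direct formula overflows for large logits, while subtracting the row maximum first keeps every exponent at or below zero. The function names here are illustrative only.

```python
import math


def naive_logsumexp(xs):
    """Direct formula; math.exp overflows for large inputs."""
    return math.log(sum(math.exp(x) for x in xs))


def stable_logsumexp(xs):
    """Subtract the max first so every exponent is <= 0."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


big = [1000.0, 999.0, 998.0]
try:
    naive_logsumexp(big)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True  # exp(1000.0) exceeds the float range

stable = stable_logsumexp(big)  # finite, close to 1000.41
```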
### Shifted Sequence Handling
- Causal (autoregressive) shifting of logits and labels is built in for next-token prediction.
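
For readers unfamiliar with the shift, the conventional next-token alignment pairs the logits at position `t` with the label at position `t + 1`, dropping the last logit and the first label. The sketch below shows that convention on plain lists; `shift_for_next_token` is a hypothetical helper, and how the package performs the shift internally is not shown here.

```python
IGNORE_INDEX = -100


def shift_for_next_token(logits, labels):
    """Pair the logits at position t with the label at position t + 1."""
    return logits[:-1], labels[1:]


# One sequence of 4 positions; each "logits" entry stands in for a vocab-sized row.
logits = ["z0", "z1", "z2", "z3"]
labels = [11, 12, IGNORE_INDEX, 14]

shifted_logits, shifted_labels = shift_for_next_token(logits, labels)
pairs = list(zip(shifted_logits, shifted_labels))
```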
---
## License
MIT License
Raw data
{
"_id": null,
"home_page": "https://github.com/Dcas89/crossentropy-triton",
"name": "crossentropy-triton",
"maintainer": null,
"docs_url": null,
"requires_python": ">=3.9",
"maintainer_email": null,
"keywords": "triton, cuda, cross-entropy, machine-learning, deep-learning, pytorch",
"author": "Daniel Castillo",
"author_email": "d.castillocastagneto@gmail.com",
"download_url": "https://files.pythonhosted.org/packages/b7/72/e236e9eccfbddbec04e1a1f9ded7241b9f915dea955785b872876d9f75b5/crossentropy_triton-0.1.2.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": null,
"summary": "A high-performance, memory-efficient cross-entropy loss implementation using Triton for CUDA GPUs",
"version": "0.1.2",
"project_urls": {
"Bug Reports": "https://github.com/Dcas89/crossentropy-triton/issues",
"Homepage": "https://github.com/Dcas89/crossentropy-triton",
"Source": "https://github.com/Dcas89/crossentropy-triton"
},
"split_keywords": [
"triton",
" cuda",
" cross-entropy",
" machine-learning",
" deep-learning",
" pytorch"
],
"urls": [
{
"comment_text": null,
"digests": {
"blake2b_256": "5079adaaa7b32ababaa689d8056bdcade41748ae99cf35e88d4010f8214bf3ea",
"md5": "a7517b10e4eac1d1f6b77a60495a4f9d",
"sha256": "58631daa6ad30d6d37b19afc71e40c0debe3967c87d429a4f4a9a3418dd5bccc"
},
"downloads": -1,
"filename": "crossentropy_triton-0.1.2-py3-none-any.whl",
"has_sig": false,
"md5_digest": "a7517b10e4eac1d1f6b77a60495a4f9d",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.9",
"size": 6497,
"upload_time": "2025-07-13T22:39:49",
"upload_time_iso_8601": "2025-07-13T22:39:49.993416Z",
"url": "https://files.pythonhosted.org/packages/50/79/adaaa7b32ababaa689d8056bdcade41748ae99cf35e88d4010f8214bf3ea/crossentropy_triton-0.1.2-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": null,
"digests": {
"blake2b_256": "b772e236e9eccfbddbec04e1a1f9ded7241b9f915dea955785b872876d9f75b5",
"md5": "f5c4cdc50531d222aa39b31cb93dab90",
"sha256": "7bbd4592b7d9b2add38906607ac54ca396b15f9f1d64d835783e8aea0000ce50"
},
"downloads": -1,
"filename": "crossentropy_triton-0.1.2.tar.gz",
"has_sig": false,
"md5_digest": "f5c4cdc50531d222aa39b31cb93dab90",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.9",
"size": 7395,
"upload_time": "2025-07-13T22:39:51",
"upload_time_iso_8601": "2025-07-13T22:39:51.227293Z",
"url": "https://files.pythonhosted.org/packages/b7/72/e236e9eccfbddbec04e1a1f9ded7241b9f915dea955785b872876d9f75b5/crossentropy_triton-0.1.2.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2025-07-13 22:39:51",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "Dcas89",
"github_project": "crossentropy-triton",
"github_not_found": true,
"lcname": "crossentropy-triton"
}