crossentropy-triton

Name: crossentropy-triton
Version: 0.1.2
Homepage: https://github.com/Dcas89/crossentropy-triton
Summary: A high-performance, memory-efficient cross-entropy loss implementation using Triton for CUDA GPUs
Upload time: 2025-07-13 22:39:51
Author: Daniel Castillo
Requires Python: >=3.9
Keywords: triton, cuda, cross-entropy, machine-learning, deep-learning, pytorch
# Triton-Optimized Cross-Entropy Kernel

A high-performance, memory-efficient cross-entropy loss implementation using [Triton](https://github.com/openai/triton) for CUDA GPUs. Significantly faster than PyTorch's native cross-entropy, especially for large vocabulary sizes in large language models.

> **Attribution:**  
> This implementation is adapted from [Unsloth's cross-entropy kernel](https://github.com/unslothai/unsloth/blob/1898b6d049d606ec88f3f9307172373776eec0f6/unsloth/kernels/cross_entropy_loss.py).

---

## Features

- **Memory Efficient:** Fused kernel reduces memory footprint.
- **High Performance:** Optimized for large vocabulary sizes with Triton JIT.
- **Causal LM Compatible:** Handles shifted logits/labels for autoregressive language modeling.
- **Ignore Index Support:** Configurable ignore index for masking tokens (default: `-100`).
- **CUDA Accelerated:** Fully utilizes CUDA GPUs for maximum throughput.
- **Autograd Compatible:** Exposes a PyTorch-compatible `autograd.Function` and `nn.Module`.

---

## Requirements

- PyTorch (CUDA-enabled)
- Triton
- CUDA-compatible GPU

---

## Installation

Install from PyPI:

```bash
pip install crossentropy-triton
```

Or install it together with PyTorch and Triton in one step:

```bash
pip install crossentropy-triton torch triton
```

---

## Usage

### Basic Usage (Autograd Function)

```python
import torch
from crossentropy_triton import CrossEntropyFunction

device = torch.device('cuda')

# Create sample data [batch, seq, vocab_size]
logits = torch.randn(2, 10, 32000, device=device, requires_grad=True)
labels = torch.randint(0, 32000, (2, 10), device=device)

# Forward pass with ignore_index=-100 (default for masked tokens)
loss = CrossEntropyFunction.apply(logits, labels, -100)
print(f"Loss: {loss.item():.4f}")

# Backward pass
loss.backward()
print(f"Gradients computed - shape: {logits.grad.shape}")
```
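
For a quick sanity check, you can compare against PyTorch's native loss. The snippet below assumes `CrossEntropyFunction` averages over non-ignored tokens in the same way as `F.cross_entropy(..., ignore_index=-100)`; verify this against your installed version. Because the fused kernel stores gradients in the logits tensor itself (see Technical Details), the Triton call is given its own copy of the logits.

```python
import torch
import torch.nn.functional as F
from crossentropy_triton import CrossEntropyFunction

device = torch.device('cuda')
vocab_size = 32000

logits = torch.randn(2, 10, vocab_size, device=device)
labels = torch.randint(0, vocab_size, (2, 10), device=device)
labels[0, :3] = -100  # mask a few tokens

# PyTorch reference (mean over non-ignored tokens)
ref = F.cross_entropy(
    logits.reshape(-1, vocab_size), labels.reshape(-1), ignore_index=-100
)

# The Triton function gets its own copy because the kernel writes gradients
# in place into the logits it is given.
triton_logits = logits.clone().requires_grad_(True)
triton_loss = CrossEntropyFunction.apply(triton_logits, labels, -100)

print(f"Triton: {triton_loss.item():.6f}  PyTorch: {ref.item():.6f}")
```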

### Using the Causal LM Loss Module

```python
import torch
from crossentropy_triton import TritonCausalLMLoss

device = torch.device('cuda')
vocab_size = 32000

# Initialize the loss function
loss_fn = TritonCausalLMLoss(vocab_size)

# Create sample data
logits = torch.randn(2, 10, vocab_size, device=device, requires_grad=True)
labels = torch.randint(0, vocab_size, (2, 10), device=device)

# Forward and backward pass
loss = loss_fn(logits, labels)
print(f"Causal LM loss: {loss.item():.4f}")

loss.backward()
print("Backward pass completed")
```
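
In a training loop the module is used like any other loss function. The toy model below is only a stand-in to make the example self-contained; per the feature list above, the module applies the causal shift itself, so raw token ids can be passed as labels.

```python
import torch
from crossentropy_triton import TritonCausalLMLoss

device = torch.device('cuda')
vocab_size, hidden = 32000, 256

# Toy "language model": an embedding followed by a projection back to the vocab.
model = torch.nn.Sequential(
    torch.nn.Embedding(vocab_size, hidden),
    torch.nn.Linear(hidden, vocab_size),
).to(device)
loss_fn = TritonCausalLMLoss(vocab_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

tokens = torch.randint(0, vocab_size, (2, 10), device=device)
for step in range(3):
    logits = model(tokens)          # [batch, seq, vocab_size]
    loss = loss_fn(logits, tokens)  # causal shift handled inside the module
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"step {step}: loss {loss.item():.4f}")
```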

---

## Performance Characteristics

- **Optimized Block Size:** Picks the kernel block size from the vocabulary size, up to 32,768.
- **Memory Fusion:** Fuses softmax and gradient computation in a single kernel.
- **Efficient Masking:** Ignore index is handled directly in the kernel.
- **Gradient Scaling:** Loss and gradients are normalized by the number of non-ignored tokens (see the sketch below).
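
The block-size and normalization bullets translate to roughly the following host-side logic. The helper names here are illustrative only and are not part of the package's API.

```python
import torch
import triton


def choose_block_size(vocab_size: int, cap: int = 32_768) -> int:
    # Round the vocabulary size up to the next power of two, capped at 32,768.
    return min(triton.next_power_of_2(vocab_size), cap)


def mean_over_valid(per_token_loss: torch.Tensor,
                    labels: torch.Tensor,
                    ignore_index: int = -100) -> torch.Tensor:
    # Sum of per-token losses divided by the number of non-ignored tokens.
    n_valid = (labels != ignore_index).sum().clamp(min=1)
    return per_token_loss.sum() / n_valid
```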

---

## Technical Details

### Kernel Implementation

- **`cross_entropy_kernel`:** Computes the per-row loss and writes the gradients in place into the logits tensor during the forward pass (a simplified sketch follows this list).
- **`element_mul_kernel`:** Scales the in-place gradients by the incoming `grad_output` during the backward pass.
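
For intuition, here is a heavily simplified, standalone Triton kernel in the same spirit: one program handles one row of logits, computes the loss with the log-sum-exp trick, and overwrites the row with its gradient. This is a sketch only, not the package's actual `cross_entropy_kernel`; names, arguments, and the handling of very large vocabularies differ.

```python
import triton
import triton.language as tl


@triton.jit
def ce_row_kernel(
    logits_ptr,     # float32 logits, one row per token; overwritten with gradients
    labels_ptr,     # integer labels, one per row
    loss_ptr,       # float32 per-row loss output
    n_cols,         # vocabulary size
    row_stride,     # elements between consecutive logit rows
    IGNORE_INDEX: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,  # next power of two >= n_cols
):
    row = tl.program_id(0)
    label = tl.load(labels_ptr + row).to(tl.int32)

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    row_ptr = logits_ptr + row * row_stride
    logits = tl.load(row_ptr + cols, mask=mask, other=-float("inf"))

    # Log-sum-exp trick: subtract the row max before exponentiating.
    row_max = tl.max(logits, axis=0)
    shifted = logits - row_max
    exp_shifted = tl.exp(shifted)
    sum_exp = tl.sum(exp_shifted, axis=0)

    # loss = logsumexp(logits) - logits[label]; grad = softmax(logits) - one_hot(label)
    target_shifted = tl.sum(tl.where(cols == label, shifted, 0.0), axis=0)
    valid = (label != IGNORE_INDEX).to(tl.float32)  # ignored rows contribute nothing
    loss = valid * (tl.log(sum_exp) - target_shifted)
    grads = valid * (exp_shifted / sum_exp - tl.where(cols == label, 1.0, 0.0))

    # Overwrite the logits row with its gradient; write the per-row loss.
    tl.store(row_ptr + cols, grads, mask=mask)
    tl.store(loss_ptr + row, loss)
```

One program is launched per row, e.g. `ce_row_kernel[(n_rows,)](logits_2d, labels, losses, vocab, logits_2d.stride(0), IGNORE_INDEX=-100, BLOCK_SIZE=triton.next_power_of_2(vocab))`, and the per-row losses are then reduced on the host; scaling by `grad_output` and by the non-ignored token count happens later in the backward pass.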

### Memory and Numerical Stability

- Supports both contiguous and non-contiguous tensors.
- In-place gradient computation for minimal overhead.
- Log-sum-exp trick for stable softmax.

### Shifted Sequence Handling

- The causal (autoregressive) shift of logits and labels for next-token prediction is built in (see the reference sketch below).
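
In plain PyTorch terms, the shift makes position `t` predict token `t+1`. The reference function below shows the equivalent eager computation for illustration; it is not the module's internals, which perform the same shift but run the fused kernel.

```python
import torch
import torch.nn.functional as F


def shifted_reference_loss(logits: torch.Tensor, labels: torch.Tensor,
                           ignore_index: int = -100) -> torch.Tensor:
    # Drop the last logit step and the first label so position t predicts token t+1.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )
```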

---

## License

MIT License

            
