ejkernel

- Name: ejkernel
- Version: 0.0.2
- Summary: Accelerate and optimize performance with streamlined training and serving options with JAX.
- Author: Erfan Zare Chavoshi
- Requires Python: >=3.11, <3.14
- License: Apache-2.0
- Keywords: deep learning, machine learning, jax, cuda, xla, triton, pallas
- Uploaded: 2025-10-21 18:38:55
# ejKernel: High-Performance JAX Kernels for Deep Learning

> *"The best optimization is the one you don't have to think about."*

[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/)
[![JAX](https://img.shields.io/badge/JAX-0.7.2+-orange.svg)](https://github.com/google/jax)
[![Documentation](https://img.shields.io/badge/docs-readthedocs-green.svg)](https://ejkernel.readthedocs.io/en/latest/)

ejKernel is a production-grade kernel library for JAX that provides highly optimized implementations of deep learning operations with automatic multi-backend support. The library features a sophisticated configuration management system with autotuning, comprehensive type safety, and seamless execution across GPUs, TPUs, and CPUs.

## Table of Contents

- [Key Features](#key-features)
- [Installation](#installation)
- [Quick Start](#quick-start)
- [Architecture Overview](#architecture-overview)
- [Supported Operations](#supported-operations)
- [Advanced Usage](#advanced-usage)
- [Performance](#performance)
- [Development](#development)
- [Testing](#testing)
- [Contributing](#contributing)
- [Citation](#citation)
- [License](#license)

## Key Features

### Intelligent Kernel Management

- **7-Tier Configuration System**: Override → Overlay → Memory Cache → Persistent Cache → Autotune → Heuristics → Error
- **Automatic Platform Detection**: Seamlessly selects optimal implementation based on hardware
- **Priority-Based Registry**: Multi-backend support with intelligent fallback mechanisms
- **Device Fingerprinting**: Hardware-specific configuration caching for optimal performance

### State-of-the-Art Operations

- **15+ Attention Mechanisms**: Flash Attention v2, Ring Attention, Page Attention, Block Sparse, GLA, Lightning, and more
- **Memory Efficiency**: Custom VJP implementations with O(N) memory complexity for attention
- **Distributed Support**: Full shard_map integration for model and data parallelism
- **Mixed Precision**: Comprehensive dtype support with automatic gradient conversion

### Production-Ready Infrastructure

- **Type Safety**: Full jaxtyping annotations with runtime validation via beartype
- **Comprehensive Testing**: Cross-backend validation, performance benchmarks, integration tests
- **Atomic Persistence**: Thread-safe configuration storage with automatic optimization
- **Profiling Integration**: Built-in support for JAX profiling and performance monitoring

## Installation

### Basic Installation

```bash
pip install ejkernel
```

### Platform-Specific Installation

```bash
# GPU Support (CUDA/ROCm)
pip install ejkernel[gpu]

# TPU Support
pip install ejkernel[tpu]

# Development Installation
git clone https://github.com/erfanzar/ejkernel.git
cd ejkernel
pip install -e ".[dev]"
```
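
A quick smoke test after installation (this assumes the package exposes `__version__`, which not every release does):

```python
import ejkernel
import jax

print(ejkernel.__version__)  # assumed attribute; adjust if your release lacks it
print(jax.devices())         # confirm JAX sees your accelerator
```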

### Dependencies

- Python 3.11-3.13
- JAX >= 0.7.2
- Triton == 3.4.0 (for GPU)
- jaxtyping >= 0.3.2
- beartype >= 0.22.2

## Quick Start

### Simple API with Automatic Optimization

```python
import jax
import jax.numpy as jnp
from ejkernel.modules import flash_attention

# Toy inputs; a (batch, seq_len, num_heads, head_dim) layout is assumed here
q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(q_rng, (2, 1024, 8, 64), dtype=jnp.bfloat16)
key = jax.random.normal(k_rng, (2, 1024, 8, 64), dtype=jnp.bfloat16)
value = jax.random.normal(v_rng, (2, 1024, 8, 64), dtype=jnp.bfloat16)

# Basic usage - automatic configuration selection
output = flash_attention(
    query, key, value,
    causal=True,
    dropout_prob=0.1
)

# With advanced features
mask = jnp.ones((2, 1024, 1024), dtype=jnp.bool_)  # placeholder; the exact mask layout may differ
output = flash_attention(
    query, key, value,
    causal=True,
    sliding_window=128,     # Local attention window
    logits_soft_cap=30.0,   # Gemma-2 style soft capping
    attention_mask=mask,    # Custom attention pattern
)
```

### Custom Configuration

```python
from ejkernel.modules import FlashAttentionConfig
from ejkernel.ops.utils.datacarrier import FwdParams, BwdParams

# Create optimized configuration
config = FlashAttentionConfig(
    fwd_params=FwdParams(
        q_blocksize=256,
        kv_blocksize=256,
        num_warps=8,
        num_stages=2
    ),
    bwd_params=BwdParams(
        q_blocksize=128,
        kv_blocksize=128,
        num_warps=4
    ),
    platform="triton",  # Force specific backend
    backend="gpu"
)

output = flash_attention(query, key, value, cfg=config)
```
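
As a rule of thumb, backward-pass block sizes are often chosen smaller than forward-pass ones (as in the example above) because the backward pass keeps more intermediate tensors live per block; treat these values as starting points for autotuning rather than universal defaults.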

### Direct Kernel Registry Access

```python
from ejkernel import kernel_registry, Platform, Backend

# Get specific implementation
kernel = kernel_registry.get(
    algorithm="flash_attention",
    platform=Platform.TRITON,
    backend=Backend.GPU
)

# Direct execution
output = kernel(query, key, value, causal=True)
```

### Distributed Execution

```python
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from ejkernel.modules import flash_attention

# Set up a 2D device mesh; Mesh expects a device ndarray whose rank matches axis_names
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

# Run distributed attention
output = flash_attention(
    query, key, value,
    causal=True,
    mesh=mesh,
    in_specs=(P("data", None), P("data", None), P("data", None)),
    out_specs=P("data", None)
)
```

## Architecture Overview

### System Design

ejKernel employs a sophisticated layered architecture that separates concerns while maintaining high performance:

```md
┌─────────────────────────────────────────────────────┐
│              Public API (modules/)                   │
│         Simple functions with sensible defaults      │
├─────────────────────────────────────────────────────┤
│            Operations Layer (ops/)                   │
│    Configuration management, autotuning, caching     │
├─────────────────────────────────────────────────────┤
│          Kernel Registry (kernels/)                  │
│      Platform routing, signature validation          │
├─────────────────────────────────────────────────────┤
│      Backend Implementations (kernels/_*)            │
│         Triton, Pallas, XLA, CUDA kernels           │
└─────────────────────────────────────────────────────┘
```

### Project Structure

```md
ejkernel/
├── kernels/
│   ├── _triton/         # GPU kernels via Triton
│   ├── _pallas/         # TPU/GPU kernels via Pallas
│   │   ├── tpu/        # TPU-specific implementations
│   │   └── gpu/        # GPU Pallas implementations
│   ├── _xla/           # Universal XLA implementations
│   └── _cuda/          # Native CUDA kernels
├── modules/
│   └── operations/     # High-level API modules
├── ops/
│   ├── config/         # Configuration management
│   ├── core/           # Base kernel classes
│   ├── execution/      # Execution orchestration
│   └── utils/          # Fingerprinting, utilities
├── xla_utils/          # XLA-specific utilities
└── callib/             # Calibration utilities
```
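
The top-level directories map directly onto the import paths used throughout this README:

```python
# High-level operations live under modules/
from ejkernel.modules import flash_attention

# Registry and platform routing come from the package root (backed by kernels/)
from ejkernel import kernel_registry, Platform, Backend
```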

### Core Components

#### Kernel Registry

The registry provides automatic platform-specific kernel selection:

```python
@kernel_registry.register("my_operation", Platform.TRITON, Backend.GPU, priority=100)
def my_operation_gpu(x, y):
    # GPU-optimized implementation
    pass

@kernel_registry.register("my_operation", Platform.XLA, Backend.ANY, priority=50)
def my_operation_fallback(x, y):
    # Universal fallback
    pass

# Automatic selection based on available hardware
impl = kernel_registry.get("my_operation")
```

#### Configuration Management

Multi-tier configuration system with intelligent fallback:

```python
class ConfigSelectorChain:
    """
    Selection hierarchy:
    1. Override - Explicit user configuration
    2. Overlay - Temporary context overrides
    3. Memory Cache - In-memory lookup
    4. Persistent Cache - Disk-based storage
    5. Autotune - Performance benchmarking
    6. Heuristics - Intelligent defaults
    7. Error - Clear failure message
    """
```
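
Conceptually, the chain is a chain-of-responsibility: each tier either returns a configuration or defers to the next. A minimal sketch of the idea (illustrative only, not ejKernel's actual internals):

```python
from typing import Callable, Optional

# A resolver inspects an invocation and returns a config, or None to defer.
Resolver = Callable[[dict], Optional[dict]]

def resolve_config(invocation: dict, resolvers: list[Resolver]) -> dict:
    """Walk the tiers in order; the first non-None answer wins."""
    for resolve in resolvers:
        cfg = resolve(invocation)
        if cfg is not None:
            return cfg
    # Tier 7: fail loudly with a clear message rather than guessing
    raise LookupError(f"No configuration found for {invocation!r}")
```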

#### Custom VJP System

All performance-critical kernels implement memory-efficient gradients:

```python
import jax

# forward, forward_with_residuals, and efficient_backward are placeholders
# for the kernel's actual implementations.
@jax.custom_vjp
def kernel_with_custom_grad(inputs):
    return forward(inputs)

def kernel_fwd(inputs):
    # Save only the residuals the backward pass needs, not full activations
    output, residuals = forward_with_residuals(inputs)
    return output, residuals

def kernel_bwd(residuals, grad_output):
    # Must return a tuple of cotangents, one per primal argument
    return (efficient_backward(residuals, grad_output),)

kernel_with_custom_grad.defvjp(kernel_fwd, kernel_bwd)
```
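
A concrete toy instance of this pattern, recomputing a cheap intermediate in the backward pass instead of storing it (the same trade-off flash-style kernels make at much larger scale):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def silu(x):
    return x * jax.nn.sigmoid(x)

def silu_fwd(x):
    # Save only x as the residual, not sigmoid(x)
    return silu(x), x

def silu_bwd(x, g):
    s = jax.nn.sigmoid(x)  # recomputed here to save forward-pass memory
    return (g * (s + x * s * (1 - s)),)

silu.defvjp(silu_fwd, silu_bwd)

grads = jax.grad(lambda x: silu(x).sum())(jnp.linspace(-2.0, 2.0, 5))
```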

## Supported Operations

### Attention Mechanisms

| Algorithm | Description | Memory | Key Features |
|-----------|-------------|--------|--------------|
| **Flash Attention v2** | Memory-efficient exact attention | O(N) | Causal masking, dropout, sliding windows, soft capping |
| **Ring Attention** | Distributed sequence parallelism | O(N/P) | Ultra-long sequences, communication overlap |
| **Page Attention** | KV-cache optimized inference | O(N) | Block-wise memory, continuous batching |
| **Block Sparse Attention** | Configurable sparse patterns | O(N√N) | Local+global, custom patterns |
| **GLA** | Gated Linear Attention | O(N) | Linear complexity, gated updates |
| **Lightning Attention** | Layer-dependent decay | O(N) | Exponential moving average |
| **MLA** | Multi-head Latent Attention | O(N) | Compressed KV representation |
| **Ragged Attention** | Variable-length sequences | O(N) | Efficient padding, batched inference |
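
The O(N) entries above follow from the standard online-softmax recurrence: key/value blocks are streamed while only running statistics are kept per query row. With scores `s_j = q . k_j`, each incoming block updates a running max, normalizer, and output:

```latex
m^{(j)} = \max\!\big(m^{(j-1)},\, s_j\big), \qquad
\ell^{(j)} = \ell^{(j-1)} e^{m^{(j-1)} - m^{(j)}} + e^{s_j - m^{(j)}}, \qquad
o^{(j)} = o^{(j-1)} e^{m^{(j-1)} - m^{(j)}} + e^{s_j - m^{(j)}} v_j
```

The final output is `o^{(T)} / \ell^{(T)}`, so per-row state is constant regardless of sequence length.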

### Other Operations

- **Recurrent Kernels**: Optimized RNN/LSTM/GRU operations
- **Mean Pooling**: Variable-length sequence aggregation
- **Grouped MatMul**: Efficient batched matrix operations
- **Native Sparse**: Block-sparse matrix computations

### Platform Support Matrix

| Operation | Triton (GPU) | Pallas (TPU) | XLA (Universal) | CUDA |
|-----------|-------------|--------------|-----------------|------|
| Flash Attention v2 | ✓ | ✓ | ✓ | Dev |
| Ring Attention | ✓ | ✓ | ✓ | Dev |
| Page Attention | ✓ | ✓ | ✓ | Dev |
| Block Sparse | ✓ | - | ✓ | Dev |
| GLA | ✓ | Dev | ✓ | - |
| Lightning | ✓ | - | ✓ | Dev |
| MLA | ✓ | Dev | - | - |
| Ragged Attention | ✓ | ✓ | ✓ | Dev |

✓ = Production ready | Dev = Under development | - = Not planned
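
At runtime you can pick a column of this matrix from the detected JAX backend; a small illustrative helper (the mapping below is an assumption, not an official API):

```python
import jax

def default_platform() -> str:
    """Map the active JAX backend to a kernel platform name used above."""
    backend = jax.default_backend()  # "gpu", "tpu", or "cpu"
    return {"gpu": "triton", "tpu": "pallas"}.get(backend, "xla")
```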

## Advanced Usage

### Performance Optimization

```python
# Force autotuning for optimal configuration
import os
os.environ["EJKERNEL_AUTOTUNE_POLICY"] = "autotune"
os.environ["EJKERNEL_LOG_AUTOTUNE"] = "1"

# Enable profiling
os.environ["EJKERNEL_OPS_STAMP"] = "json"  # Detailed metadata
os.environ["EJKERNEL_OPS_RECORD"] = "1"    # Record invocations
```

### Custom Kernel Development

```python
from dataclasses import dataclass

from jaxtyping import Array

from ejkernel import kernel_registry
from ejkernel.ops.core import Kernel
from ejkernel.modules.operations.configs import BaseOperationConfig

@dataclass
class MyConfig(BaseOperationConfig):
    param1: int = 128
    param2: float = 0.1

class MyKernel(Kernel[MyConfig, Array]):
    def __init__(self):
        super().__init__(op_id="my_kernel")

    def run(self, x, cfg: MyConfig):
        impl = kernel_registry.get("my_kernel", cfg.platform)
        return impl(x, param1=cfg.param1, param2=cfg.param2)

    def heuristic_cfg(self, inv):
        # Return default configuration
        return MyConfig(param1=256)

    def candidate_cfgs(self, inv):
        # Return autotuning candidates
        return [MyConfig(param1=p) for p in [64, 128, 256]]
```
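
Conceptually, `candidate_cfgs` feeds an autotuner that times each candidate and caches the winner per device fingerprint. A simplified sketch of that loop (illustrative only; the real executor also handles warm-up caching and error fallback):

```python
import time
import jax

def autotune(kernel, x, candidates):
    best_cfg, best_time = None, float("inf")
    for cfg in candidates:
        jax.block_until_ready(kernel.run(x, cfg))  # warm-up / compile
        start = time.perf_counter()
        jax.block_until_ready(kernel.run(x, cfg))  # timed run
        elapsed = time.perf_counter() - start
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    return best_cfg
```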

### Integration with Models

```python
import flax.linen as nn

from ejkernel.modules import flash_attention

class TransformerBlock(nn.Module):
    num_heads: int = 8
    head_dim: int = 64

    @nn.compact
    def __call__(self, x, mask=None):
        # Project to Q, K, V
        q = nn.Dense(self.num_heads * self.head_dim)(x)
        k = nn.Dense(self.num_heads * self.head_dim)(x)
        v = nn.Dense(self.num_heads * self.head_dim)(x)

        # Reshape for attention
        shape = (x.shape[0], x.shape[1], self.num_heads, self.head_dim)
        q, k, v = map(lambda t: t.reshape(shape), (q, k, v))

        # Apply ejKernel Flash Attention
        attn_output = flash_attention(
            q, k, v,
            causal=True,
            attention_mask=mask
        )

        # Project output
        return nn.Dense(x.shape[-1])(attn_output.reshape(x.shape))
```
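
Initializing and applying the block follows the usual Flax pattern (shapes here are illustrative):

```python
import jax
import jax.numpy as jnp

model = TransformerBlock()                     # 8 heads x 64 dims = 512 features
x = jnp.ones((2, 128, 512))                    # (batch, seq_len, features)
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)                     # (2, 128, 512)
```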

## Performance

Performance numbers are tracked alongside the documentation; see the [Benchmarks](https://ejkernel.readthedocs.io/en/latest/benchmarks/) page for current results.

## Development

### Setting Up Development Environment

```bash
# Clone repository
git clone https://github.com/erfanzar/ejkernel.git
cd ejkernel

# Create virtual environment
python -m venv .venv
source .venv/bin/activate  # On Windows: .venv\Scripts\activate

# Install in development mode
pip install -e ".[dev]"

# Install pre-commit hooks
pre-commit install
```

### Code Style

The project uses:

- **black** for code formatting (line length: 121)
- **ruff** for linting
- **mypy/pyright** for type checking
- **pre-commit** for automated checks

### Adding New Kernels

1. **Implement the kernel** in appropriate backend directory:

```python
# ejkernel/kernels/_triton/my_kernel.py
@kernel_registry.register("my_kernel", Platform.TRITON, Backend.GPU)
def my_kernel_triton(x, config):
    # Implementation
    pass
```

2. **Create module wrapper**:

```python
# ejkernel/modules/operations/my_kernel.py
class MyKernel(Kernel[MyKernelConfig, Array]):
    # Module implementation
    pass
```

3. **Add tests**:

```python
# test/kernels/_triton/test_my_kernel.py
import unittest

class TestMyKernel(unittest.TestCase):
    # Test implementation
    pass
```

4. **Update documentation**

## Testing

### Running Tests

```bash
# Run all tests
python test/run_tests.py

# Platform-specific tests
python test/run_tests.py --xla      # XLA implementations
python test/run_tests.py --triton   # Triton implementations
python test/run_tests.py --pallas   # Pallas implementations

# Cross-platform validation
python test/run_tests.py --comparison

# Specific test patterns
python test/run_tests.py -k "flash_attention"
python test/run_tests.py --verbose --failfast
```

### Test Categories

- **Unit Tests**: Individual component testing
- **Integration Tests**: End-to-end workflows
- **Comparison Tests**: Cross-backend consistency (see the sketch below)
- **Performance Tests**: Regression detection
- **Property Tests**: Invariant verification
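
A comparison test typically checks that two registered backends agree within tolerance, reusing the registry API shown earlier (test shapes and tolerances here are illustrative):

```python
import unittest

import jax
import jax.numpy as jnp

from ejkernel import kernel_registry, Platform, Backend

class TestFlashAttentionComparison(unittest.TestCase):
    def test_triton_matches_xla(self):
        q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)
        shape = (1, 128, 4, 64)
        q = jax.random.normal(q_rng, shape)
        k = jax.random.normal(k_rng, shape)
        v = jax.random.normal(v_rng, shape)

        reference = kernel_registry.get("flash_attention", Platform.XLA, Backend.ANY)
        candidate = kernel_registry.get("flash_attention", Platform.TRITON, Backend.GPU)

        expected = reference(q, k, v, causal=True)
        actual = candidate(q, k, v, causal=True)
        self.assertTrue(jnp.allclose(actual, expected, atol=1e-2))
```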

### Continuous Integration

The project uses GitHub Actions for CI with tests across:

- Multiple Python versions (3.11, 3.12, 3.13)
- Multiple platforms (CPU, GPU, TPU)
- Multiple JAX versions

## Contributing

We welcome contributions! See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.

### Priority Areas

- TPU/Pallas implementations for existing algorithms
- CUDA native kernels for maximum performance
- New attention mechanisms from recent papers
- Performance optimizations and kernel fusion
- Documentation and examples

### Contribution Process

1. Fork the repository
2. Create a feature branch
3. Implement your changes with tests
4. Ensure all tests pass
5. Submit a pull request

## Documentation

Comprehensive documentation is available at [ejkernel.readthedocs.io](https://ejkernel.readthedocs.io/en/latest/).

- **[API Reference](https://ejkernel.readthedocs.io/en/latest/api/)**: Complete API documentation
- **[Tutorials](https://ejkernel.readthedocs.io/en/latest/tutorials/)**: Step-by-step guides
- **[Architecture](https://ejkernel.readthedocs.io/en/latest/architecture/)**: Design documentation
- **[Benchmarks](https://ejkernel.readthedocs.io/en/latest/benchmarks/)**: Performance analysis

## Citation

If you use ejKernel in your research, please cite:

```bibtex
@software{ejkernel2024,
  author = {Erfan Zare Chavoshi},
  title = {ejKernel: High-Performance JAX Kernels for Deep Learning},
  year = {2024},
  url = {https://github.com/erfanzar/ejkernel},
  note = {Production-grade kernel library with multi-backend support}
}
```

## License

ejKernel is licensed under the Apache License 2.0. See [LICENSE](LICENSE) for details.

## Acknowledgments

ejKernel builds upon excellent work from:

- [JAX](https://github.com/google/jax) - Composable transformations of Python+NumPy programs
- [Triton](https://github.com/openai/triton) - GPU kernel programming language
- [Pallas](https://github.com/google/jax/tree/main/jax/experimental/pallas) - JAX kernel language
- [Flash Attention](https://github.com/Dao-AILab/flash-attention) - Memory-efficient attention
- [EasyDeL](https://github.com/erfanzar/EasyDeL) - Parent framework for JAX deep learning

## Community

- **GitHub Issues**: [Bug reports and feature requests](https://github.com/erfanzar/ejkernel/issues)
- **Discussions**: [Community forum](https://github.com/erfanzar/ejkernel/discussions)
- **Email**: <Erfanzare810@gmail.com>

## Roadmap

### Near Term (Q1 2025)

- Flash Attention 3 implementation
- Complete CUDA backend
- Quantized attention (INT8/INT4)
- Fused operations (LayerNorm+Attention)

### Medium Term (Q2-Q3 2025)

- Speculative decoding support
- Continuous batching
- Mamba SSM kernels

### Long Term (Q4 2025+)

- Multi-GPU kernel fusion
- Automatic kernel selection ML model
- Custom DSL for kernel development
- Hardware-agnostic IR

---

ejKernel - Production-grade kernels for JAX deep learning

            
