# ejKernel: High-Performance JAX Kernels for Deep Learning
> *"The best optimization is the one you don't have to think about."*
[License: Apache 2.0](https://opensource.org/licenses/Apache-2.0)
[Python 3.11–3.13](https://www.python.org/downloads/)
[JAX](https://github.com/google/jax)
[Documentation](https://ejkernel.readthedocs.io/en/latest/)
ejKernel is a production-grade kernel library for JAX that provides highly optimized implementations of deep learning operations with automatic multi-backend support. The library features a sophisticated configuration management system with autotuning, comprehensive type safety, and seamless execution across GPUs, TPUs, and CPUs.
## Table of Contents
- [Key Features](#key-features)
- [Installation](#installation)
- [Quick Start](#quick-start)
- [Architecture Overview](#architecture-overview)
- [Supported Operations](#supported-operations)
- [Advanced Usage](#advanced-usage)
- [Performance](#performance)
- [Development](#development)
- [Testing](#testing)
- [Contributing](#contributing)
- [Citation](#citation)
- [License](#license)
## Key Features
### Intelligent Kernel Management
- **7-Tier Configuration System**: Override → Overlay → Memory Cache → Persistent Cache → Autotune → Heuristics → Error
- **Automatic Platform Detection**: Seamlessly selects optimal implementation based on hardware
- **Priority-Based Registry**: Multi-backend support with intelligent fallback mechanisms
- **Device Fingerprinting**: Hardware-specific configuration caching for optimal performance
### State-of-the-Art Operations
- **15+ Attention Mechanisms**: Flash Attention v2, Ring Attention, Page Attention, Block Sparse, GLA, Lightning, and more
- **Memory Efficiency**: Custom VJP implementations with O(N) memory complexity for attention
- **Distributed Support**: Full shard_map integration for model and data parallelism
- **Mixed Precision**: Comprehensive dtype support with automatic gradient conversion
### Production-Ready Infrastructure
- **Type Safety**: Full jaxtyping annotations with runtime validation via beartype (see the sketch after this list)
- **Comprehensive Testing**: Cross-backend validation, performance benchmarks, integration tests
- **Atomic Persistence**: Thread-safe configuration storage with automatic optimization
- **Profiling Integration**: Built-in support for JAX profiling and performance monitoring
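As a minimal sketch of that annotation style (the function below is illustrative, not part of the ejKernel API), jaxtyping shape annotations are validated at runtime by beartype:
```python
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped

@jaxtyped(typechecker=beartype)
def scale_activations(x: Float[Array, "batch seq dim"], scale: float) -> Float[Array, "batch seq dim"]:
    # A shape or dtype mismatch raises a clear runtime error instead of failing silently
    return x * scale
```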
## Installation
### Basic Installation
```bash
pip install ejkernel
```
### Platform-Specific Installation
```bash
# GPU Support (CUDA/ROCm)
pip install ejkernel[gpu]
# TPU Support
pip install ejkernel[tpu]
# Development Installation
git clone https://github.com/erfanzar/ejkernel.git
cd ejkernel
pip install -e ".[dev]"
```
### Dependencies
- Python 3.11-3.13
- JAX >= 0.7.2
- Triton == 3.4.0 (for GPU)
- jaxtyping >= 0.3.2
- beartype >= 0.22.2
## Quick Start
### Simple API with Automatic Optimization
```python
import jax.numpy as jnp
from ejkernel.modules import flash_attention
# Example inputs; shapes assume the [batch, seq_len, num_heads, head_dim] layout used in the Integration with Models example below
query = key = value = jnp.zeros((1, 2048, 8, 64), dtype=jnp.bfloat16)
mask = None  # optionally, a boolean attention mask
# Basic usage - automatic configuration selection
output = flash_attention(
query, key, value,
causal=True,
dropout_prob=0.1
)
# With advanced features
output = flash_attention(
query, key, value,
causal=True,
sliding_window=128, # Local attention window
logits_soft_cap=30.0, # Gemma-2 style soft capping
attention_mask=mask, # Custom attention pattern
)
```
### Custom Configuration
```python
from ejkernel.modules import FlashAttentionConfig
from ejkernel.ops.utils.datacarrier import FwdParams, BwdParams
# Create optimized configuration
config = FlashAttentionConfig(
fwd_params=FwdParams(
q_blocksize=256,
kv_blocksize=256,
num_warps=8,
num_stages=2
),
bwd_params=BwdParams(
q_blocksize=128,
kv_blocksize=128,
num_warps=4
),
platform="triton", # Force specific backend
backend="gpu"
)
output = flash_attention(query, key, value, cfg=config)
```
### Direct Kernel Registry Access
```python
from ejkernel import kernel_registry, Platform, Backend
# Get specific implementation
kernel = kernel_registry.get(
algorithm="flash_attention",
platform=Platform.TRITON,
backend=Backend.GPU
)
# Direct execution
output = kernel(query, key, value, causal=True)
```
### Distributed Execution
```python
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from ejkernel.modules import flash_attention
# Setup mesh for distributed execution; the device grid must be 2-D to match the two axis names
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))
# Run distributed attention
output = flash_attention(
query, key, value,
causal=True,
mesh=mesh,
in_specs=(P("data", None), P("data", None), P("data", None)),
out_specs=P("data", None)
)
```
## Architecture Overview
### System Design
ejKernel employs a sophisticated layered architecture that separates concerns while maintaining high performance:
```md
┌─────────────────────────────────────────────────────┐
│ Public API (modules/) │
│ Simple functions with sensible defaults │
├─────────────────────────────────────────────────────┤
│ Operations Layer (ops/) │
│ Configuration management, autotuning, caching │
├─────────────────────────────────────────────────────┤
│ Kernel Registry (kernels/) │
│ Platform routing, signature validation │
├─────────────────────────────────────────────────────┤
│ Backend Implementations (kernels/_*) │
│ Triton, Pallas, XLA, CUDA kernels │
└─────────────────────────────────────────────────────┘
```
### Project Structure
```md
ejkernel/
├── kernels/
│ ├── _triton/ # GPU kernels via Triton
│ ├── _pallas/ # TPU/GPU kernels via Pallas
│ │ ├── tpu/ # TPU-specific implementations
│ │ └── gpu/ # GPU Pallas implementations
│ ├── _xla/ # Universal XLA implementations
│ └── _cuda/ # Native CUDA kernels
├── modules/
│ └── operations/ # High-level API modules
├── ops/
│ ├── config/ # Configuration management
│ ├── core/ # Base kernel classes
│ ├── execution/ # Execution orchestration
│ └── utils/ # Fingerprinting, utilities
├── xla_utils/ # XLA-specific utilities
└── callib/ # Calibration utilities
```
### Core Components
#### Kernel Registry
The registry provides automatic platform-specific kernel selection:
```python
@kernel_registry.register("my_operation", Platform.TRITON, Backend.GPU, priority=100)
def my_operation_gpu(x, y):
# GPU-optimized implementation
pass
@kernel_registry.register("my_operation", Platform.XLA, Backend.ANY, priority=50)
def my_operation_fallback(x, y):
# Universal fallback
pass
# Automatic selection based on available hardware
impl = kernel_registry.get("my_operation")
```
#### Configuration Management
Multi-tier configuration system with intelligent fallback:
```python
class ConfigSelectorChain:
"""
Selection hierarchy:
1. Override - Explicit user configuration
2. Overlay - Temporary context overrides
3. Memory Cache - In-memory lookup
4. Persistent Cache - Disk-based storage
5. Autotune - Performance benchmarking
6. Heuristics - Intelligent defaults
7. Error - Clear failure message
"""
```
#### Custom VJP System
All performance-critical kernels implement memory-efficient gradients:
```python
@jax.custom_vjp
def kernel_with_custom_grad(inputs):
return forward(inputs)
def kernel_fwd(inputs):
output, residuals = forward_with_residuals(inputs)
return output, residuals
def kernel_bwd(residuals, grad_output):
return efficient_backward(residuals, grad_output)
kernel_with_custom_grad.defvjp(kernel_fwd, kernel_bwd)
```
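The functions above are schematic. As a self-contained toy instance of the same pattern (not an ejKernel kernel), here is a numerically stable softmax whose forward pass saves its output as the residual and whose backward rule returns one gradient per primal input:
```python
import jax
import jax.numpy as jnp

def _softmax_impl(x):
    z = x - jnp.max(x, axis=-1, keepdims=True)
    e = jnp.exp(z)
    return e / jnp.sum(e, axis=-1, keepdims=True)

@jax.custom_vjp
def stable_softmax(x):
    return _softmax_impl(x)

def stable_softmax_fwd(x):
    y = _softmax_impl(x)
    return y, y  # the softmax output is the only residual the backward needs

def stable_softmax_bwd(y, g):
    # d/dx softmax: y * (g - sum(g * y)) along the softmax axis; returned as a 1-tuple (one primal input)
    return (y * (g - jnp.sum(g * y, axis=-1, keepdims=True)),)

stable_softmax.defvjp(stable_softmax_fwd, stable_softmax_bwd)

grads = jax.grad(lambda x: stable_softmax(x).sum())(jnp.arange(4.0))  # uses the custom backward rule
```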
## Supported Operations
### Attention Mechanisms
| Algorithm | Description | Memory | Key Features |
|-----------|-------------|--------|--------------|
| **Flash Attention v2** | Memory-efficient exact attention | O(N) | Causal masking, dropout, sliding windows, soft capping |
| **Ring Attention** | Distributed sequence parallelism | O(N/P) | Ultra-long sequences, communication overlap |
| **Page Attention** | KV-cache optimized inference | O(N) | Block-wise memory, continuous batching |
| **Block Sparse Attention** | Configurable sparse patterns | O(N√N) | Local+global, custom patterns |
| **GLA** | Gated Linear Attention | O(N) | Linear complexity, gated updates |
| **Lightning Attention** | Layer-dependent decay | O(N) | Exponential moving average |
| **MLA** | Multi-head Latent Attention | O(N) | Compressed KV representation |
| **Ragged Attention** | Variable-length sequences | O(N) | Efficient padding, batched inference |
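For scale: at a 32k-token sequence with fp16 scores, materializing the full N×N attention matrix for a single head would take 32768² × 2 bytes ≈ 2 GiB, which is exactly the allocation the O(N)-memory kernels above avoid.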
### Other Operations
- **Recurrent Kernels**: Optimized RNN/LSTM/GRU operations
- **Mean Pooling**: Variable-length sequence aggregation
- **Grouped MatMul**: Efficient batched matrix operations
- **Native Sparse**: Block-sparse matrix computations
### Platform Support Matrix
| Operation | Triton (GPU) | Pallas (TPU) | XLA (Universal) | CUDA |
|-----------|-------------|--------------|-----------------|------|
| Flash Attention v2 | ✓ | ✓ | ✓ | Dev |
| Ring Attention | ✓ | ✓ | ✓ | Dev |
| Page Attention | ✓ | ✓ | ✓ | Dev |
| Block Sparse | ✓ | - | ✓ | Dev |
| GLA | ✓ | Dev | ✓ | - |
| Lightning | ✓ | - | ✓ | Dev |
| MLA | ✓ | Dev | - | - |
| Ragged Attention | ✓ | ✓ | ✓ | Dev |
✓ = Production ready | Dev = Under development | - = Not planned
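The registry lookup from the Quick Start can also express a manual preference with fallback across columns of this matrix. The sketch below reuses only the `kernel_registry.get` call shown earlier; the exception type raised for an unavailable implementation is an assumption, so a broad `except` is used:
```python
from ejkernel import kernel_registry, Platform, Backend

def get_flash_attention_impl():
    """Prefer the Triton GPU kernel; fall back to the universal XLA implementation."""
    try:
        return kernel_registry.get(
            algorithm="flash_attention",
            platform=Platform.TRITON,
            backend=Backend.GPU,
        )
    except Exception:  # exact error type is an assumption
        return kernel_registry.get(
            algorithm="flash_attention",
            platform=Platform.XLA,
            backend=Backend.ANY,
        )
```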
## Advanced Usage
### Performance Optimization
```python
# Force autotuning for optimal configuration
import os
os.environ["EJKERNEL_AUTOTUNE_POLICY"] = "autotune"
os.environ["EJKERNEL_LOG_AUTOTUNE"] = "1"
# Enable profiling
os.environ["EJKERNEL_OPS_STAMP"] = "json" # Detailed metadata
os.environ["EJKERNEL_OPS_RECORD"] = "1" # Record invocations
```
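These toggles can be combined with JAX's built-in profiler (standard `jax.profiler` API, not ejKernel-specific; the input shapes below are illustrative) to inspect kernel launches:
```python
import jax
import jax.numpy as jnp
from ejkernel.modules import flash_attention

# Illustrative inputs in the [batch, seq_len, num_heads, head_dim] layout
q = k = v = jnp.zeros((1, 1024, 8, 64), dtype=jnp.bfloat16)

# Write a trace viewable in TensorBoard / Perfetto
with jax.profiler.trace("/tmp/ejkernel-trace"):
    out = flash_attention(q, k, v, causal=True)
    out.block_until_ready()
```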
### Custom Kernel Development
```python
from dataclasses import dataclass
from jax import Array
from ejkernel import kernel_registry
from ejkernel.ops.core import Kernel
from ejkernel.modules.operations.configs import BaseOperationConfig
@dataclass
class MyConfig(BaseOperationConfig):
param1: int = 128
param2: float = 0.1
class MyKernel(Kernel[MyConfig, Array]):
def __init__(self):
super().__init__(op_id="my_kernel")
def run(self, x, cfg: MyConfig):
impl = kernel_registry.get("my_kernel", cfg.platform)
return impl(x, param1=cfg.param1, param2=cfg.param2)
def heuristic_cfg(self, inv):
# Return default configuration
return MyConfig(param1=256)
def candidate_cfgs(self, inv):
# Return autotuning candidates
return [MyConfig(param1=p) for p in [64, 128, 256]]
```
### Integration with Models
```python
import flax.linen as nn
from ejkernel.modules import flash_attention
class TransformerBlock(nn.Module):
num_heads: int = 8
head_dim: int = 64
@nn.compact
def __call__(self, x, mask=None):
# Project to Q, K, V
q = nn.Dense(self.num_heads * self.head_dim)(x)
k = nn.Dense(self.num_heads * self.head_dim)(x)
v = nn.Dense(self.num_heads * self.head_dim)(x)
# Reshape for attention
shape = (x.shape[0], x.shape[1], self.num_heads, self.head_dim)
q, k, v = map(lambda t: t.reshape(shape), (q, k, v))
# Apply ejKernel Flash Attention
attn_output = flash_attention(
q, k, v,
causal=True,
attention_mask=mask
)
# Project output
return nn.Dense(x.shape[-1])(attn_output.reshape(x.shape))
```
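The block above can then be initialized and applied like any other Flax module (shapes here are illustrative):
```python
import jax
import jax.numpy as jnp

model = TransformerBlock(num_heads=8, head_dim=64)
x = jnp.zeros((2, 128, 512))  # [batch, seq_len, hidden]
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)    # -> (2, 128, 512)
```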
## Performance
## Development
### Setting Up Development Environment
```bash
# Clone repository
git clone https://github.com/erfanzar/ejkernel.git
cd ejkernel
# Create virtual environment
python -m venv .venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
# Install in development mode
pip install -e ".[dev]"
# Install pre-commit hooks
pre-commit install
```
### Code Style
The project uses:
- **black** for code formatting (line length: 121)
- **ruff** for linting
- **mypy/pyright** for type checking
- **pre-commit** for automated checks
### Adding New Kernels
1. **Implement the kernel** in the appropriate backend directory:
```python
# ejkernel/kernels/_triton/my_kernel.py
@kernel_registry.register("my_kernel", Platform.TRITON, Backend.GPU)
def my_kernel_triton(x, config):
# Implementation
pass
```
2. **Create module wrapper**:
```python
# ejkernel/modules/operations/my_kernel.py
class MyKernel(Kernel[MyKernelConfig, Array]):
# Module implementation
pass
```
3. **Add tests**:
```python
# test/kernels/_triton/test_my_kernel.py
class TestMyKernel(unittest.TestCase):
# Test implementation
pass
```
4. **Update documentation**
## Testing
### Running Tests
```bash
# Run all tests
python test/run_tests.py
# Platform-specific tests
python test/run_tests.py --xla # XLA implementations
python test/run_tests.py --triton # Triton implementations
python test/run_tests.py --pallas # Pallas implementations
# Cross-platform validation
python test/run_tests.py --comparison
# Specific test patterns
python test/run_tests.py -k "flash_attention"
python test/run_tests.py --verbose --failfast
```
### Test Categories
- **Unit Tests**: Individual component testing
- **Integration Tests**: End-to-end workflows
- **Comparison Tests**: Cross-backend consistency
- **Performance Tests**: Regression detection
- **Property Tests**: Invariant verification
### Continuous Integration
The project uses GitHub Actions for CI with tests across:
- Multiple Python versions (3.11, 3.12, 3.13)
- Multiple platforms (CPU, GPU, TPU)
- Multiple JAX versions
## Contributing
We welcome contributions! See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
### Priority Areas
- TPU/Pallas implementations for existing algorithms
- CUDA native kernels for maximum performance
- New attention mechanisms from recent papers
- Performance optimizations and kernel fusion
- Documentation and examples
### Contribution Process
1. Fork the repository
2. Create a feature branch
3. Implement your changes with tests
4. Ensure all tests pass
5. Submit a pull request
## Documentation
Comprehensive documentation is available at [ejkernel.readthedocs.io](https://ejkernel.readthedocs.io/en/latest/):
- **[API Reference](https://ejkernel.readthedocs.io/en/latest/api/)**: Complete API documentation
- **[Tutorials](https://ejkernel.readthedocs.io/en/latest/tutorials/)**: Step-by-step guides
- **[Architecture](https://ejkernel.readthedocs.io/en/latest/architecture/)**: Design documentation
- **[Benchmarks](https://ejkernel.readthedocs.io/en/latest/benchmarks/)**: Performance analysis
## Citation
If you use ejKernel in your research, please cite:
```bibtex
@software{ejkernel2024,
author = {Erfan Zare Chavoshi},
title = {ejKernel: High-Performance JAX Kernels for Deep Learning},
year = {2024},
url = {https://github.com/erfanzar/ejkernel},
note = {Production-grade kernel library with multi-backend support}
}
```
## License
ejKernel is licensed under the Apache License 2.0. See [LICENSE](LICENSE) for details.
## Acknowledgments
ejKernel builds upon excellent work from:
- [JAX](https://github.com/google/jax) - Composable transformations of Python+NumPy programs
- [Triton](https://github.com/openai/triton) - GPU kernel programming language
- [Pallas](https://github.com/google/jax/tree/main/jax/experimental/pallas) - JAX kernel language
- [Flash Attention](https://github.com/Dao-AILab/flash-attention) - Memory-efficient attention
- [EasyDeL](https://github.com/erfanzar/EasyDeL) - Parent framework for JAX deep learning
## Community
- **GitHub Issues**: [Bug reports and feature requests](https://github.com/erfanzar/ejkernel/issues)
- **Discussions**: [Community forum](https://github.com/erfanzar/ejkernel/discussions)
- **Email**: <Erfanzare810@gmail.com>
## Roadmap
### Near Term (Q1 2025)
- Flash Attention 3 implementation
- Complete CUDA backend
- Quantized attention (INT8/INT4)
- Fused operations (LayerNorm+Attention)
### Medium Term (Q2-Q3 2025)
- Speculative decoding support
- Continuous batching
- Mamba SSM kernels
### Long Term (Q4 2025+)
- Multi-GPU kernel fusion
- Automatic kernel selection ML model
- Custom DSL for kernel development
- Hardware-agnostic IR
---
ejKernel - Production-grade kernels for JAX deep learning