# eformer (EasyDel Former)
**eformer** (EasyDel Former) is a utility library that simplifies building machine learning models with JAX. It provides tools for distributed computing, custom data structures, numerical optimization, and high-performance operations, so that models are easier to build, scale, and optimize on top of JAX.
## Project Structure Overview
The library is organized into several core modules:
- **`aparser`**: Advanced argument parsing utilities with dataclass integration
- **`callib`**: Custom function calling and Triton kernel integration
- **`common_types`**: Shared type definitions and sharding constants
- **`escale`**: Distributed sharding and parallelism utilities
- **`executor`**: Execution management and hardware-specific optimizations
- **`jaximus`**: Custom PyTree implementations and structured array utilities
- **`mpric`**: Mixed precision training and dynamic scaling infrastructure
- **`ops`**: Optimized operations including Flash Attention and quantization
- **`optimizers`**: Flexible optimizer configuration and factory patterns
- **`pytree`**: Enhanced tree manipulation and transformation utilities
## Key Features
### 1. Mixed Precision Training (`mpric`)
Mixed precision utilities supporting float8, float16, and bfloat16, with dynamic loss scaling, for faster training and a smaller memory footprint.
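
Dynamic loss scaling multiplies the loss by a scale factor before backpropagation so that small gradients do not underflow in float16 or float8, then divides the gradients by the same factor before the optimizer step. The sketch below shows the common update rule: on overflow, halve the scale and skip the step; after a run of finite steps, grow the scale again. It is a generic plain-Python illustration, not eformer's implementation. The class name `DynamicLossScale` and its parameters (`growth_interval`, `factor`, `min_scale`) are hypothetical.

```python
import math
from dataclasses import dataclass


@dataclass
class DynamicLossScale:
    """Hypothetical dynamic loss-scale state (illustrative, not eformer's API)."""

    scale: float = 2.0**15        # current multiplier applied to the loss
    growth_interval: int = 2000   # finite steps required before growing the scale
    factor: float = 2.0           # growth and backoff multiplier
    min_scale: float = 1.0        # never shrink below this value
    good_steps: int = 0           # consecutive steps with finite gradients

    def scale_loss(self, loss: float) -> float:
        # Scale the loss up so small gradients survive low-precision backprop.
        return loss * self.scale

    def unscale_grads(self, grads: list[float]) -> list[float]:
        # Undo the scaling before the optimizer sees the gradients.
        return [g / self.scale for g in grads]

    def update(self, grads: list[float]) -> bool:
        """Adjust the scale; return True if the optimizer step should be applied."""
        if not all(math.isfinite(g) for g in grads):
            # Overflow (inf/nan): shrink the scale, reset the counter, skip the step.
            self.scale = max(self.scale / self.factor, self.min_scale)
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps >= self.growth_interval:
            # A long run of finite steps: try a larger scale.
            self.scale *= self.factor
            self.good_steps = 0
        return True
```

Skipping the step on overflow matters as much as shrinking the scale: a single inf or nan gradient would otherwise corrupt the parameters.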
### 2. Distributed Sharding (`escale`)
Tools for efficient sharding and distributed computation in JAX that let you scale models across multiple devices using several sharding strategies:
- Data Parallelism (`DP`)
- Fully Sharded Data Parallel (`FSDP`)
- Tensor Parallelism (`TP`)
- Expert Parallelism (`EP`)
- Sequence Parallelism (`SP`)
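
To make the first few strategies concrete, the plain-Python sketch below computes which slice of a batch and of a weight matrix each device would hold on a 2-D device mesh with a data axis (`dp`) and a tensor axis (`tp`). Data parallelism splits the batch across `dp`, tensor parallelism splits the weight's output columns across `tp`, and FSDP additionally splits the weight's input rows across `dp`. The helpers `block` and `shard_slices` and this mesh layout are hypothetical illustrations; in real JAX code the mapping is declared with `jax.sharding.Mesh` and `PartitionSpec` rather than computed by hand.

```python
from itertools import product


def block(size: int, parts: int, index: int) -> range:
    """Contiguous block `index` out of `parts` equal blocks of range(size)."""
    if size % parts:
        raise ValueError(f"{size} is not divisible by {parts}")
    step = size // parts
    return range(index * step, (index + 1) * step)


def shard_slices(batch: int, d_in: int, d_out: int, dp: int, tp: int, fsdp: bool = False):
    """Map each (dp_rank, tp_rank) device to its batch rows and weight block.

    - Data parallelism: the batch is split across the `dp` axis.
    - Tensor parallelism: the weight's output columns are split across `tp`.
    - FSDP (optional): the weight's input rows are also split across `dp`.
    """
    layout = {}
    for d, t in product(range(dp), range(tp)):
        batch_rows = block(batch, dp, d)
        weight_rows = block(d_in, dp, d) if fsdp else range(d_in)
        weight_cols = block(d_out, tp, t)
        layout[(d, t)] = (batch_rows, weight_rows, weight_cols)
    return layout


for device, (b, wr, wc) in shard_slices(8, 4, 6, dp=2, tp=3, fsdp=True).items():
    print(device, "batch", list(b), "W rows", list(wr), "W cols", list(wc))
```

Without `fsdp=True`, every device on the same `tp` rank holds an identical copy of its weight columns; FSDP removes that duplication at the cost of gathering the rows before use.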
### 3. Custom PyTrees (`jaximus`)
Utilities for creating custom PyTrees and `ArrayValue` objects, adapted from Equinox, that provide flexible data structures for your models.
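
A PyTree is any nested structure that can be split into a flat list of leaves plus a description of its structure, and rebuilt from the two; JAX transformations act on the leaves. The plain-Python sketch below shows that round trip for a quantized-array-like container. `QuantizedArray`, `flatten`, and `unflatten` are hypothetical; the two methods on the class follow the protocol that `jax.tree_util.register_pytree_node_class` expects (`tree_flatten` returning `(children, aux_data)` and a `tree_unflatten` classmethod), while Equinox-style base classes register their subclasses automatically.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class QuantizedArray:
    """Hypothetical container: quantized values plus a scale, with static metadata."""

    weight: list[int]
    scale: float
    block_size: int = 32  # static metadata, not a leaf

    def tree_flatten(self):
        # Dynamic children become leaves; aux_data is kept in the structure.
        return (self.weight, self.scale), self.block_size

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        weight, scale = children
        return cls(weight, scale, aux_data)


def flatten(tree: Any):
    """Return (leaves, treedef) for nested lists/tuples/dicts and QuantizedArray."""
    if isinstance(tree, QuantizedArray):
        children, aux = tree.tree_flatten()
        node, children = ("custom", aux), list(children)
    elif isinstance(tree, (list, tuple)):
        node, children = (type(tree), None), list(tree)
    elif isinstance(tree, dict):
        keys = sorted(tree)  # deterministic key order
        node, children = (dict, keys), [tree[k] for k in keys]
    else:
        return [tree], None  # None marks a leaf
    leaves, child_defs = [], []
    for child in children:
        child_leaves, child_def = flatten(child)
        leaves.extend(child_leaves)
        child_defs.append(child_def)
    return leaves, (node, child_defs)


def unflatten(treedef, leaves):
    """Rebuild a tree from `treedef`, consuming `leaves` in order."""
    it = iter(leaves)

    def build(td):
        if td is None:
            return next(it)
        (kind, aux), child_defs = td
        children = [build(cd) for cd in child_defs]
        if kind == "custom":
            return QuantizedArray.tree_unflatten(aux, children)
        if kind is dict:
            return dict(zip(aux, children))
        return kind(children)  # list or tuple

    return build(treedef)


params = {"w": QuantizedArray([1, -2, 3], 0.5), "b": [0.1, 0.2]}
leaves, treedef = flatten(params)
doubled = unflatten(treedef, [x * 2 for x in leaves])  # same structure, new leaves
```

Keeping `block_size` in `aux_data` rather than in the leaves is the design point: static metadata stays out of the arrays that transformations such as `jit` or `grad` trace.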
### 4. Triton Integration (`callib`)
Custom function-calling utilities that integrate Triton kernels directly into JAX, so you can optimize performance-critical operations.
### 5. Optimizer Factory (`optimizers`)
A flexible factory for creating and configuring optimizers like AdamW, Adafactor, Lion, and RMSProp, making it easy to experiment with different optimization strategies.
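
The factory pattern can be sketched in a few lines: a registry maps an optimizer name to a builder, and a single entry point turns a name plus keyword arguments into an `(init, update)` pair. The sketch below works on plain lists of floats and uses hypothetical names (`register`, `create_optimizer`); it is not eformer's actual factory signature. The AdamW builder uses the standard bias-corrected update with decoupled weight decay, and the Lion builder uses the published sign-based update.

```python
import math
from typing import Callable

# Registry mapping a lowercase name to a builder returning (init_fn, update_fn).
_REGISTRY: dict[str, Callable] = {}


def register(name: str):
    def decorator(builder: Callable) -> Callable:
        _REGISTRY[name] = builder
        return builder
    return decorator


def create_optimizer(name: str, **hyperparams):
    """Hypothetical factory entry point: look up `name` and build the optimizer."""
    try:
        builder = _REGISTRY[name.lower()]
    except KeyError:
        raise ValueError(f"unknown optimizer {name!r}; choose from {sorted(_REGISTRY)}") from None
    return builder(**hyperparams)


@register("adamw")
def adamw(lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01):
    def init(params):
        return {"step": 0, "m": [0.0] * len(params), "v": [0.0] * len(params)}

    def update(params, grads, state):
        t = state["step"] + 1
        m = [b1 * mi + (1 - b1) * g for mi, g in zip(state["m"], grads)]
        v = [b2 * vi + (1 - b2) * g * g for vi, g in zip(state["v"], grads)]
        new_params = []
        for p, mi, vi in zip(params, m, v):
            m_hat = mi / (1 - b1**t)  # bias-corrected first moment
            v_hat = vi / (1 - b2**t)  # bias-corrected second moment
            # Decoupled weight decay: applied to the parameter, not the gradient.
            new_params.append(p - lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * p))
        return new_params, {"step": t, "m": m, "v": v}

    return init, update


@register("lion")
def lion(lr=1e-4, b1=0.9, b2=0.99, weight_decay=0.0):
    def init(params):
        return {"m": [0.0] * len(params)}

    def update(params, grads, state):
        new_params, new_m = [], []
        for p, mi, g in zip(params, state["m"], grads):
            c = b1 * mi + (1 - b1) * g  # interpolation whose sign sets the step
            direction = math.copysign(1.0, c) if c else 0.0
            new_params.append(p - lr * (direction + weight_decay * p))
            new_m.append(b2 * mi + (1 - b2) * g)  # momentum is tracked with b2
        return new_params, {"m": new_m}

    return init, update
```

Because every builder returns the same `(init, update)` shape, swapping `"adamw"` for `"lion"` in a config changes nothing else in the training loop.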
### 6. Optimized Operations (`ops`)
- Flash Attention 2 for GPUs and TPUs (implemented with Triton and Pallas) for faster attention computation
- 8-bit and NF4 quantization for efficient model deployment
- Additional optimized operations under active development
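
The core idea behind Flash Attention is an online softmax. Keys and values are processed in blocks, and a running maximum and running normalizer let earlier partial results be rescaled whenever a larger score appears, so the full attention matrix never has to be stored. The plain-Python sketch below computes attention for a single query both ways. It illustrates the algorithm only; it is not the Triton or Pallas kernel in `eformer.ops`, and the function names are hypothetical.

```python
import math


def naive_attention(q, keys, values):
    """Reference: softmax(q . k / sqrt(d)) weighted sum of values, for one query."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim_v)]


def blockwise_attention(q, keys, values, block_size=2):
    """Same result, computed one key/value block at a time with an online softmax."""
    d = len(q)
    dim_v = len(values[0])
    running_max = -math.inf  # largest score seen so far
    running_sum = 0.0        # sum of exp(score - running_max)
    acc = [0.0] * dim_v      # unnormalized output, relative to running_max
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new maximum
        # (exp(-inf) == 0.0 on the first block, so nothing stale survives).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]
```

Each step only needs the running maximum, the running sum, and one accumulator, so for a given query the working memory depends on the block size rather than on the sequence length.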
## API Documentation
For detailed API references and usage examples, see:
- [Argument Parser (`aparser`)](docs/api_docs/aparser.rst)
- [Triton Integration (`callib`)](docs/api_docs/callib.rst)
- [Sharding Utilities (`escale`)](docs/api_docs/escale.rst)
- [Execution Management (`executor`)](docs/api_docs/executor.rst)
- [Mixed Precision Infrastructure (`mpric`)](docs/api_docs/mpric.rst)
- [Custom Operations (`ops`)](docs/api_docs/ops.rst)
## Installation
`eformer` requires Python 3.10 or newer (below 3.14) and can be installed via pip:
```bash
pip install eformer
```
## Getting Started
### Mixed Precision Handler with mpric
```python
from eformer.mpric import PrecisionHandler
# Create a handler that keeps parameters and outputs in float32
# and runs computation in float8 (E4M3)
handler = PrecisionHandler(
    policy="p=f32,c=f8_e4m3,o=f32",  # params in f32, compute in float8, output in f32
    use_dynamic_scale=True,  # enable dynamic loss scaling
)
```
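
The policy string packs three dtypes into one argument: `p` for parameters, `c` for computation, and `o` for outputs. The sketch below parses such a string into a small dataclass to show the format. `parse_policy` and `Policy` are hypothetical helpers, not part of eformer, and the abbreviation table (for example, mapping `f8_e4m3` to `float8_e4m3fn`) is an assumption modeled on JAX and `ml_dtypes` dtype names.

```python
from dataclasses import dataclass

# Assumed abbreviation table; canonical names mirror JAX / ml_dtypes dtype names.
_DTYPE_ALIASES = {
    "f32": "float32",
    "f16": "float16",
    "bf16": "bfloat16",
    "f8_e4m3": "float8_e4m3fn",
    "f8_e5m2": "float8_e5m2",
}
# Short and long key spellings both map to one field name.
_KEY_ALIASES = {"p": "param", "params": "param", "c": "compute", "compute": "compute",
                "o": "output", "output": "output"}


@dataclass(frozen=True)
class Policy:
    param: str
    compute: str
    output: str


def parse_policy(spec: str) -> Policy:
    """Parse 'p=f32,c=f8_e4m3,o=f32' into a Policy (hypothetical helper)."""
    fields = {}
    for item in spec.split(","):
        key, sep, value = item.strip().partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        key, value = key.strip().lower(), value.strip().lower()
        if key not in _KEY_ALIASES:
            raise ValueError(f"unknown policy key {key!r}")
        fields[_KEY_ALIASES[key]] = _DTYPE_ALIASES.get(value, value)
    missing = {"param", "compute", "output"} - set(fields)
    if missing:
        raise ValueError(f"policy is missing {sorted(missing)}")
    return Policy(**fields)


print(parse_policy("p=f32,c=f8_e4m3,o=f32"))
# Policy(param='float32', compute='float8_e4m3fn', output='float32')
```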
### Custom PyTree Implementation
```python
import jax
import jax.numpy as jnp
from eformer.jaximus import ArrayValue, implicit
from eformer.ops.quantization.quantization_functions import dequantize_row_q8_0, quantize_row_q8_0


class Array8B(ArrayValue):
    """An 8-bit quantized array that dequantizes itself on demand."""

    scale: jax.Array
    weight: jax.Array

    def __init__(self, array: jax.Array):
        # Store the quantized weights and their scales instead of the dense array.
        self.weight, self.scale = quantize_row_q8_0(array)

    def materialize(self):
        # Rebuild a dense array from the quantized representation.
        return dequantize_row_q8_0(self.weight, self.scale)


array = jax.random.normal(jax.random.key(0), (256, 64), jnp.float16)
qarray = Array8B(array)
```
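
The `q8_0` name comes from the GGML block format: values are grouped into blocks of 32, each block stores one scale `d = max|x| / 127`, and each value is stored as the integer `round(x / d)`. The plain-Python sketch below shows that round trip under this assumption. It is not eformer's `quantize_row_q8_0`, whose return layout and dtypes may differ, and the function names are hypothetical.

```python
BLOCK = 32  # values per block in the GGML Q8_0 layout


def quantize_q8_0(values: list[float]) -> tuple[list[int], list[float]]:
    """Return (int8 codes, one scale per block). Length must be a multiple of BLOCK."""
    if len(values) % BLOCK:
        raise ValueError(f"length {len(values)} is not a multiple of {BLOCK}")
    codes, scales = [], []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        amax = max(abs(v) for v in block)
        d = amax / 127.0             # map the largest magnitude to +/-127
        inv = 1.0 / d if d else 0.0  # an all-zero block gets scale 0
        # Python's round() sends exact halves to even; C's roundf rounds them
        # away from zero. The two differ only on exact ties.
        codes.extend(round(v * inv) for v in block)
        scales.append(d)
    return codes, scales


def dequantize_q8_0(codes: list[int], scales: list[float]) -> list[float]:
    """Reconstruct values as code * scale of the code's block."""
    return [q * scales[i // BLOCK] for i, q in enumerate(codes)]
```

Storing one scale per 32 values keeps the overhead small while letting each block adapt to its own range, so a single outlier only degrades precision within its block.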
## Contributing
We welcome contributions! Please read our [Contributing Guidelines](CONTRIBUTING.md) to get started.
## License
This project is licensed under the Apache License 2.0. See the [LICENSE](LICENSE) file for details.
Raw data
{
"_id": null,
"home_page": null,
"name": "eformer",
"maintainer": null,
"docs_url": null,
"requires_python": "<3.14,>=3.10",
"maintainer_email": null,
"keywords": "JAX, Deep Learning, Machine Learning, Flax, XLA, EasyDeL",
"author": "Erfan Zare Chavoshi",
"author_email": "Erfan Zare Chavoshi <Erfanzare810@gmail.com>",
"download_url": "https://files.pythonhosted.org/packages/fa/09/ff50b488cb1966c9d8bda4bddb5224902fb707ffdf1f8cb0a963255dc087/eformer-0.0.45.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": "Apache-2.0",
"summary": "(EasyDel Former) is a utility library designed to simplify and enhance the development in JAX",
"version": "0.0.45",
"project_urls": {
"Documentation": "https://erfanzar.github.io/eformer",
"Homepage": "https://github.com/erfanzar/eformer",
"Repository": "https://github.com/erfanzar/eformer"
},
"split_keywords": [
"jax",
" deep learning",
" machine learning",
" flax",
" xla",
" easydel"
],
"urls": [
{
"comment_text": null,
"digests": {
"blake2b_256": "7971c58f7bae8b274f8888af249fb4a05bec5d3d37041441b4ca7d210f968cf6",
"md5": "6e591458f149220d14510bc761de930c",
"sha256": "d9b11ed9b67419211755c502b30b45e72c3ac1407305addfa395271dccc84467"
},
"downloads": -1,
"filename": "eformer-0.0.45-py3-none-any.whl",
"has_sig": false,
"md5_digest": "6e591458f149220d14510bc761de930c",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": "<3.14,>=3.10",
"size": 138601,
"upload_time": "2025-07-20T22:24:37",
"upload_time_iso_8601": "2025-07-20T22:24:37.377529Z",
"url": "https://files.pythonhosted.org/packages/79/71/c58f7bae8b274f8888af249fb4a05bec5d3d37041441b4ca7d210f968cf6/eformer-0.0.45-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": null,
"digests": {
"blake2b_256": "fa09ff50b488cb1966c9d8bda4bddb5224902fb707ffdf1f8cb0a963255dc087",
"md5": "4cd4e5807fd3f4a933d48d6fe7e62c7b",
"sha256": "8d26b4b2c1dfb5ad62c63b43f7db7ae64ce4b93914d20d43675b80fb645331ba"
},
"downloads": -1,
"filename": "eformer-0.0.45.tar.gz",
"has_sig": false,
"md5_digest": "4cd4e5807fd3f4a933d48d6fe7e62c7b",
"packagetype": "sdist",
"python_version": "source",
"requires_python": "<3.14,>=3.10",
"size": 99351,
"upload_time": "2025-07-20T22:24:39",
"upload_time_iso_8601": "2025-07-20T22:24:39.065488Z",
"url": "https://files.pythonhosted.org/packages/fa/09/ff50b488cb1966c9d8bda4bddb5224902fb707ffdf1f8cb0a963255dc087/eformer-0.0.45.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2025-07-20 22:24:39",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "erfanzar",
"github_project": "eformer",
"travis_ci": false,
"coveralls": false,
"github_actions": false,
"lcname": "eformer"
}