eformer

Name: eformer
Version: 0.0.45
Summary: (EasyDel Former) is a utility library designed to simplify and enhance development in JAX
Author: Erfan Zare Chavoshi
License: Apache-2.0
Requires Python: >=3.10, <3.14
Keywords: jax, deep learning, machine learning, flax, xla, easydel
Homepage: https://github.com/erfanzar/eformer
Documentation: https://erfanzar.github.io/eformer
Uploaded: 2025-07-20 22:24:39

# eformer (EasyDel Former)

[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Python](https://img.shields.io/badge/Python-3.10%2B-blue)](https://www.python.org/)
[![JAX](https://img.shields.io/badge/JAX-Compatible-brightgreen)](https://github.com/google/jax)
[![PyPI version](https://badge.fury.io/py/eformer.svg)](https://badge.fury.io/py/eformer)

**eformer** (EasyDel Former) is a utility library designed to simplify and enhance the development of machine learning models in JAX. It provides a comprehensive collection of tools for distributed computing, custom data structures, numerical optimization, and high-performance operations, making it easier to build, scale, and optimize models while leveraging JAX for high-performance computing.

## Project Structure Overview

The library is organized into several core modules:

- **`aparser`**: Advanced argument parsing utilities with dataclass integration (see the sketch after this list)
- **`callib`**: Custom function calling and Triton kernel integration
- **`common_types`**: Shared type definitions and sharding constants
- **`escale`**: Distributed sharding and parallelism utilities
- **`executor`**: Execution management and hardware-specific optimizations
- **`jaximus`**: Custom PyTree implementations and structured array utilities
- **`mpric`**: Mixed precision training and dynamic scaling infrastructure
- **`ops`**: Optimized operations including Flash Attention and quantization
- **`optimizers`**: Flexible optimizer configuration and factory patterns
- **`pytree`**: Enhanced tree manipulation and transformation utilities
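
To make the `aparser` bullet concrete, here is a stdlib-only sketch of the dataclass-to-CLI pattern it provides; this illustrates the idea, not eformer's actual API (`TrainArgs` and `parse_dataclass` are hypothetical names):

```python
import argparse
import dataclasses

@dataclasses.dataclass
class TrainArgs:
    learning_rate: float = 3e-4
    batch_size: int = 32
    model_name: str = "gpt2"

def parse_dataclass(cls):
    # Build an argparse CLI from the dataclass fields, then rebuild
    # the parsed namespace into a typed dataclass instance.
    parser = argparse.ArgumentParser()
    for field in dataclasses.fields(cls):
        parser.add_argument(f"--{field.name}", type=field.type, default=field.default)
    return cls(**vars(parser.parse_args()))

args = parse_dataclass(TrainArgs)  # e.g. --learning_rate 1e-4 --batch_size 8
```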

## Key Features

### 1. Mixed Precision Training (`mpric`)

Advanced mixed precision utilities supporting float8, float16, and bfloat16 with dynamic loss scaling, enabling faster training and reduced memory footprint.
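
Dynamic loss scaling keeps low-precision gradients from underflowing: the loss is multiplied by a scale factor before backprop, gradients are unscaled afterward, and the scale shrinks on overflow and grows otherwise. A minimal plain-JAX sketch of the update rule (illustrative only, not mpric's API):

```python
import jax
import jax.numpy as jnp

def update_loss_scale(scale, grads, growth=2.0, backoff=0.5, min_scale=1.0):
    # Overflow check: any inf/nan leaf in the unscaled gradients.
    finite = jnp.all(jnp.array([jnp.all(jnp.isfinite(g))
                                for g in jax.tree_util.tree_leaves(grads)]))
    # Shrink the scale on overflow; grow it otherwise (real schedules
    # usually grow only after N consecutive finite steps).
    new_scale = jnp.where(finite, scale * growth,
                          jnp.maximum(scale * backoff, min_scale))
    return new_scale, finite  # callers skip the optimizer step when finite is False
```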

### 2. Distributed Sharding (`escale`)

Tools for efficient sharding and distributed computation in JAX, allowing you to scale your models across multiple devices with various sharding strategies (a plain-JAX mesh sketch follows the list):

- Data Parallelism (`DP`)
- Fully Sharded Data Parallel (`FSDP`)
- Tensor Parallelism (`TP`)
- Expert Parallelism (`EP`)
- Sequence Parallelism (`SP`)
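
These strategies are typically expressed through a device mesh with named axes. The sketch below uses plain JAX APIs; the axis names are illustrative and not escale's own configuration surface:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrange the available devices into a 2D mesh:
# data-parallel ("dp") x tensor-parallel ("tp").
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("dp", "tp"))

# Shard a weight matrix along its second dimension over the "tp" axis,
# replicating it across "dp".
weights = jax.device_put(jnp.zeros((1024, 1024)),
                         NamedSharding(mesh, PartitionSpec(None, "tp")))
```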

### 3. Custom PyTrees (`jaximus`)

Enhanced utilities for creating custom PyTrees and `ArrayValue` objects, adapted from Equinox, providing flexible data structures for your models.
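
For background, a custom PyTree in plain JAX is any container registered with the tree machinery so that transformations like `jit` and `grad` can traverse it; `jaximus` builds on this idea. A minimal stdlib-JAX example:

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Scaled:
    def __init__(self, value, scale):
        self.value = value
        self.scale = scale

    def tree_flatten(self):
        return (self.value, self.scale), None  # (children, static aux data)

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

# tree_map now reaches inside Scaled, so jit/grad can too.
doubled = jax.tree_util.tree_map(lambda x: x * 2, Scaled(jnp.ones(3), 0.5))
```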

### 4. Triton Integration (`callib`)

Custom function calling utilities with direct integration of Triton kernels in JAX, allowing you to optimize performance-critical operations.
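
For context, a Triton kernel is a Python function compiled for the GPU. Below is a standard Triton vector-add kernel; how callib wires such a kernel into JAX is not shown here and would follow callib's own API:

```python
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program instance handles one BLOCK-sized slice of the inputs.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)
```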

### 5. Optimizer Factory

A flexible factory for creating and configuring optimizers like AdamW, Adafactor, Lion, and RMSProp, making it easy to experiment with different optimization strategies.
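
A factory of this kind is essentially a thin dispatch layer over optax. A minimal sketch of the pattern (eformer's actual configuration surface is richer; `make_optimizer` is a hypothetical name):

```python
import optax

def make_optimizer(name: str, learning_rate: float) -> optax.GradientTransformation:
    # Dispatch an optimizer name to the matching optax constructor.
    registry = {
        "adamw": optax.adamw,
        "adafactor": optax.adafactor,
        "lion": optax.lion,
        "rmsprop": optax.rmsprop,
    }
    return registry[name](learning_rate=learning_rate)

tx = make_optimizer("adamw", 1e-3)
```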

### 6. Optimized Operations (`ops`)

- Flash Attention 2 implementation for GPUs/TPUs (via Triton and Pallas) for faster attention computations
- 8-bit and NF4 quantization for efficient model deployment (see the sketch after this list)
- Additional optimized operations under active development
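
The core idea behind 8-bit quantization is to store int8 weights plus a per-row scale and dequantize on the fly. A minimal absmax sketch in plain JAX (illustrative; not the library's q8_0 kernels):

```python
import jax.numpy as jnp

def quantize_absmax_int8(x):
    # A per-row absmax scale maps each row onto the int8 range [-127, 127].
    scale = jnp.max(jnp.abs(x), axis=-1, keepdims=True) / 127.0
    scale = jnp.where(scale == 0, 1.0, scale)  # guard all-zero rows
    return jnp.round(x / scale).astype(jnp.int8), scale

def dequantize_absmax_int8(q, scale):
    return q.astype(jnp.float32) * scale
```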

## API Documentation

For detailed API references and usage examples, see:

- [Argument Parser (`aparser`)](docs/api_docs/aparser.rst)
- [Triton Integration (`callib`)](docs/api_docs/callib.rst)
- [Sharding Utilities (`escale`)](docs/api_docs/escale.rst)
- [Execution Management (`executor`)](docs/api_docs/executor.rst)
- [Mixed Precision Infrastructure (`mpric`)](docs/api_docs/mpric.rst)
- [Custom Operations (`ops`)](docs/api_docs/ops.rst)

## Installation

You can install `eformer` via pip:

```bash
pip install eformer
```

## Getting Started

### Mixed Precision Handler with `mpric`

```python
from eformer.mpric import PrecisionHandler

# Create a handler with float8 compute precision
handler = PrecisionHandler(
    policy="p=f32,c=f8_e4m3,o=f32",  # params in f32, compute in float8, output in f32
    use_dynamic_scale=True
)
```

### Custom PyTree Implementation

```python
import jax
from eformer.jaximus import ArrayValue
from eformer.ops.quantization.quantization_functions import dequantize_row_q8_0, quantize_row_q8_0

class Array8B(ArrayValue):
    """An int8-quantized array that dequantizes on demand."""
    scale: jax.Array
    weight: jax.Array

    def __init__(self, array: jax.Array):
        # Quantize once at construction; keep int8 weights plus per-row scales.
        self.weight, self.scale = quantize_row_q8_0(array)

    def materialize(self):
        # Rebuild the full-precision array from the quantized representation.
        return dequantize_row_q8_0(self.weight, self.scale)

array = jax.random.normal(jax.random.key(0), (256, 64), "f2")  # "f2" == float16
qarray = Array8B(array)
```
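
Since `materialize()` rebuilds the full-precision values from the stored int8 weights and per-row scales, a quick round-trip check is possible (a sketch; the tolerance assumes roughly unit-scale inputs and dequantization within quantization error):

```python
import jax.numpy as jnp

assert jnp.allclose(qarray.materialize(), array, atol=5e-2)
```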

## Contributing

We welcome contributions! Please read our [Contributing Guidelines](CONTRIBUTING.md) to get started.

## License

This project is licensed under the Apache License 2.0. See the [LICENSE](LICENSE) file for details.

            
