eformer

Name: eformer
Version: 0.0.14
Home page: https://github.com/erfanzar/eformer
Summary: (EasyDel Former) is a utility library designed to simplify and enhance development in JAX
Upload time: 2025-02-26 19:12:42
Author: Erfan Zare Chavoshi
Requires Python: <4.0,>=3.10
License: Apache-2.0
Keywords: jax, deep learning, machine learning, flax, xla
Requirements: none recorded
# eformer (EasyDel Former)

[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Python](https://img.shields.io/badge/Python-3.10%2B-blue)](https://www.python.org/)
[![JAX](https://img.shields.io/badge/JAX-Compatible-brightgreen)](https://github.com/google/jax)

**eformer** (EasyDel Former) is a utility library designed to simplify and enhance the development of machine learning models using JAX. It provides a collection of tools for sharding, custom PyTrees, quantization, mixed precision training, and optimized operations, making it easier to build and scale models efficiently.

## Features

- **Mixed Precision Training (`mpric`)**: Advanced mixed precision utilities supporting float8, float16, and bfloat16 with dynamic loss scaling.
- **Sharding Utilities (`escale`)**: Tools for efficient sharding and distributed computation in JAX (a plain-JAX background sketch follows this list).
- **Custom PyTrees (`jaximus`)**: Enhanced utilities for creating custom PyTrees and `ArrayValue` objects, updated from Equinox.
- **Custom Calling (`callib`)**: A tool for custom function calls and direct integration with Triton kernels in JAX.
- **Optimizer Factory**: A flexible factory for creating and configuring optimizers like AdamW, Adafactor, Lion, and RMSProp.
- **Custom Operations and Kernels**:
  - Flash Attention 2 for GPUs/TPUs (via Triton and Pallas).
  - 8-bit and NF4 quantization for memory-efficient models.
  - More kernels to be added.
- **Quantization Support**: Tools for 8-bit and NF4 quantization, enabling memory-efficient model deployment.
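
For context on the sharding utilities, `escale` is assumed to build on JAX's standard mesh and `NamedSharding` machinery. A minimal plain-JAX sketch (no eformer APIs shown) of placing an array on a device mesh:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over all local devices and shard an array's
# leading axis across it (plain JAX; escale's own helpers not shown).
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))
x = jax.device_put(jnp.zeros((8, 128)), NamedSharding(mesh, PartitionSpec("dp")))
print(x.sharding)
```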

## Installation

You can install `eformer` via pip:

```bash
pip install eformer
```
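
Since eformer runs on top of JAX, GPU or TPU execution also requires an accelerator-enabled JAX build. These are the standard JAX install commands (not eformer-specific); check the JAX documentation for your exact platform:

```bash
# NVIDIA GPUs (CUDA 12 wheels):
pip install -U "jax[cuda12]"

# Cloud TPU:
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```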

## Quick Start

### Mixed Precision Handler with mpric

```python
from eformer.mpric import PrecisionHandler

# Create a handler with float8 compute precision
handler = PrecisionHandler(
    policy="p=f32,c=f8_e4m3,o=f32",  # params in f32, compute in float8, output in f32
    use_dynamic_scale=True
)
```
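
The policy string assigns a dtype to each of three roles: `p` for parameter storage, `c` for compute, and `o` for outputs, as the inline comment notes. The same casts in plain JAX terms (illustration only; `f8_e4m3` corresponds to `jnp.float8_e4m3fn`):

```python
import jax.numpy as jnp

param_dtype = jnp.float32          # p=f32: how parameters are stored
compute_dtype = jnp.float8_e4m3fn  # c=f8_e4m3: dtype used during compute
output_dtype = jnp.float32         # o=f32: dtype of returned values

w = jnp.ones((4, 4), param_dtype)
w_compute = w.astype(compute_dtype)  # cast down before the heavy math
y = w_compute.astype(output_dtype)   # cast results back up for stability
```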

### Customizing Arrays with ArrayValue

```python
import jax

from eformer.jaximus import ArrayValue, implicit
from eformer.ops.quantization.quantization_functions import (
    dequantize_row_q8_0,
    quantize_row_q8_0,
)

array = jax.random.normal(jax.random.key(0), (256, 64), "f2")


class Array8B(ArrayValue):
    scale: jax.Array
    weight: jax.Array

    def __init__(self, array: jax.Array):
        self.weight, self.scale = quantize_row_q8_0(array)

    def materialize(self):
        return dequantize_row_q8_0(self.weight, self.scale)


qarray = Array8B(array)


@jax.jit
@implicit
def sqrt(x):
    return jax.numpy.sqrt(x)


print(sqrt(qarray))
print(qarray)
```
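
Because `implicit` lets the quantized array flow through `jax.numpy` operations by materializing it on demand (as the `sqrt` example shows), the same pattern should extend to other ops. A sketch reusing the names defined above, assuming binary ops are intercepted the same way:

```python
import jax.numpy as jnp

@jax.jit
@implicit
def project(x, w):
    # `x` is a dense array; `w` is the quantized Array8B defined above
    # and is dequantized where the matmul needs concrete values.
    return x @ w

print(project(jnp.ones((8, 256), jnp.float16), qarray).shape)  # (8, 64)
```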

### Optimizer Factory

```python
from eformer.optimizers import OptimizerFactory, SchedulerConfig, AdamWConfig

# Create an AdamW optimizer with a cosine scheduler
scheduler_config = SchedulerConfig(scheduler_type="cosine", learning_rate=1e-3, steps=1000)
optimizer, scheduler = OptimizerFactory.create("adamw", scheduler_config, AdamWConfig())
```
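
For orientation, the factory call above is assumed to assemble standard optax building blocks; a hand-rolled equivalent (exact defaults in `AdamWConfig` and `SchedulerConfig` may differ) would look like:

```python
import optax

# Cosine learning-rate schedule decayed over 1000 steps, feeding AdamW.
schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=1000)
optimizer = optax.adamw(learning_rate=schedule)
```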

### Quantization

```python
import jax

from eformer.quantization import Array8B, ArrayNF4

# Quantize an array to 8-bit
qarray = Array8B(jax.random.normal(jax.random.key(0), (256, 64), "f2"))

# Quantize an array to NF4 (the second argument is assumed to be the block size)
n4array = ArrayNF4(jax.random.normal(jax.random.key(0), (256, 64), "f2"), 64)
```
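
Both wrappers are lossy compressions. Assuming they expose the same `materialize` method as the hand-written `Array8B` in the ArrayValue example above, a round-trip check looks like:

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (256, 64), "f2")
q = Array8B(x)

# Dequantize back to a dense array and measure the reconstruction error.
x_hat = q.materialize()
print(jnp.max(jnp.abs(x_hat.astype(jnp.float32) - x.astype(jnp.float32))))
```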

### Advanced Mixed Precision Configuration

```python
import jax.numpy as jnp

from eformer.mpric import LossScaleConfig, Policy, PrecisionHandler

# Create a custom precision policy
policy = Policy(
    param_dtype=jnp.float32,
    compute_dtype=jnp.bfloat16,
    output_dtype=jnp.float32
)

# Configure loss scaling
loss_config = LossScaleConfig(
    initial_scale=2**15,
    growth_interval=2000,
    scale_factor=2,
    min_scale=1.0
)

# Create handler with custom configuration
handler = PrecisionHandler(
    policy=policy,
    use_dynamic_scale=True,
    loss_scale_config=loss_config
)
```
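
The `LossScaleConfig` fields parameterize the usual dynamic loss-scaling rule: the loss is multiplied by the scale before differentiation, gradients are divided by it afterwards, the scale shrinks when gradients overflow, and it grows again after `growth_interval` consecutive finite steps. A generic sketch of that update (an illustration, not eformer's internal code):

```python
import jax.numpy as jnp

def update_loss_scale(scale, grads_finite, good_steps,
                      growth_interval=2000, scale_factor=2.0, min_scale=1.0):
    # Overflow: shrink the scale (never below min_scale) and reset the counter.
    shrunk = jnp.maximum(scale / scale_factor, min_scale)
    # Enough consecutive finite steps: grow the scale by scale_factor.
    grown = jnp.where(good_steps + 1 >= growth_interval, scale * scale_factor, scale)
    new_scale = jnp.where(grads_finite, grown, shrunk)
    new_steps = jnp.where(grads_finite, (good_steps + 1) % growth_interval, 0)
    return new_scale, new_steps
```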

## Contributing

We welcome contributions! Please read our [Contributing Guidelines](CONTRIBUTING.md) to get started.

## License

This project is licensed under the Apache License 2.0. See the [LICENSE](LICENSE) file for details.

            
