- Name: EasyDeL
- Version: 0.0.67
- Summary: An open-source library to make training faster and more optimized in Jax/Flax
- Upload time: 2024-06-02 17:14:40
- Requires Python: >=3.8
- License: Apache-2.0
- Keywords: jax, torch, deep learning, machine learning, flax, xla

# EasyDeL 🔮

EasyDeL is an open-source framework designed to enhance and streamline the training process of machine learning models.
With a primary focus on Jax/Flax, EasyDeL aims to provide convenient and effective solutions for training and serving
Flax/JAX models on TPU and GPU.

## Key Features

1. **Trainers**: EasyDeL offers a range of trainers, including DPOTrainer, ORPOTrainer, SFTTrainer, and VideoCLM
   Trainer, tailored for specific training requirements.

2. **Serving and API Engines**: EasyDeL provides serving and API engines for efficiently using and serving large
   language models (LLMs) in JAX, enabling seamless integration into various applications.

3. **Quantization Support**: EasyDeL supports quantization methods for all models, allowing for efficient inference and
   training.

4. **Bit Operation Support**: EasyDeL supports 8, 6, and 4-bit operations for inference and training in JAX, optimizing
   performance and resource utilization.

5. **Diverse Model Support**: EasyDeL offers a wide range of models in JAX, including several that previously had no JAX
   implementation, such as Falcon, Qwen2, Phi2, Mixtral, Qwen2Moe, Cohere, Dbrx, Phi3, and MPT.

6. **FlashAttention Integration**: EasyDeL integrates FlashAttention in JAX for GPUs and TPUs, enhancing performance and
   efficiency.

7. **Automatic LLM Serving**: EasyDeL enables automatic serving of LLMs with mid- and high-level APIs in both JAX and
   PyTorch, simplifying deployment and integration.

8. **LLM Training and Fine-tuning**: EasyDeL provides LLM trainer and fine-tuner capabilities in JAX, allowing for
   efficient training and customization of language models.

9. **Video CLM Training and Fine-tuning**: EasyDeL supports Video CLM trainer and fine-tuner for models such as Falcon,
   Qwen2, Phi2, MPT, Mixtral, Grok-1, and Qwen2Moe, enabling advanced video-related applications.

10. **Performance Optimization**: EasyDeL provides various features to enhance the training process and optimize
    performance, such as LoRA (Low-Rank Adaptation of Large Language Models), RingAttention, FlashAttention, BlockWise
    FFN, and Efficient Attention support (through the FJFormer backbone).

11. **Model Conversion**: EasyDeL supports automatic conversion of models from JAX-EasyDeL to PyTorch-HF and vice versa,
    facilitating seamless integration with different frameworks.
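
Items 3 and 4 above mention quantized 8-, 6-, and 4-bit operation. As a generic illustration only (this is not EasyDeL's or FJFormer's actual scheme, which this README does not describe), symmetric absmax quantization stores each weight as a signed integer plus one shared float scale; fewer bits mean a coarser grid and a larger rounding error:

```python
from typing import List, Tuple


def quantize_absmax(values: List[float], bits: int = 8) -> Tuple[List[int], float]:
    """Symmetric absmax quantization of floats to signed `bits`-bit integers (illustrative)."""
    qmax = 2 ** (bits - 1) - 1                  # 127 for 8 bits, 31 for 6 bits, 7 for 4 bits
    absmax = max(abs(v) for v in values)
    scale = absmax / qmax if absmax > 0 else 1.0
    # Round to the nearest grid point and clamp to the representable range
    q = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return q, scale


def dequantize(q: List[int], scale: float) -> List[float]:
    """Map the integers back to approximate float weights."""
    return [x * scale for x in q]


weights = [0.12, -0.5, 0.33, 0.9, -0.07]
for bits in (8, 6, 4):
    q, scale = quantize_absmax(weights, bits)
    restored = dequantize(q, scale)
    max_err = max(abs(a - b) for a, b in zip(weights, restored))
    print(f"{bits}-bit: max abs error = {max_err:.4f} (bound {scale / 2:.4f})")
```

Fewer bits shrink `qmax`, which enlarges `scale` and therefore the worst-case rounding error of `scale / 2`.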

With its comprehensive set of features and tools, EasyDeL aims to streamline and accelerate the training and deployment
of machine learning models, particularly in the domain of large language models and video-related applications.

### Latest News 🔥

- The `*(q,k,v,b,a)_partition_specs` arguments have been removed in favor of `PartitionAxis`.
- Sharding strategies have changed.
- EasyDeL is now more memory-efficient on multi-GPU setups.
- The Aya model is supported.
- The Falcon model has been updated; Falcon11B is now supported with FlashAttention.
- `pallas_flash` is now available for CPU/GPU/TPU with a custom Pallas kernel.
- The DeepseekV2 model has been added (beta mode).
- The OpenELM model has been added.
- The EasyDeL project structure has changed; you now have to import EasyDeL as `easydel`.
- `ORPOTrainer` has been added.
- Phi3 model bugs are fixed, and the Arctic model has been added.

> [!TIP]
>
> Use `ed.AttentionModule.test_attentions()` to find the attention mechanism
> that works best for you:
> ```python
> import easydel as ed
> ed.AttentionModule.test_attentions()
> ```

## Documentation 💫

> [!IMPORTANT]
> Documentation and examples are available [here](https://easydel.readthedocs.io/en/latest/).
> Please keep in mind that EasyDeL is under rapid development,
> so the API may change.

### Hands on Code Kaggle Examples

1. [script](https://www.kaggle.com/citifer/easydel-causal-language-model-trainer-example) showing how to use the EasyDeL
   CausalLanguageModelTrainer on Kaggle; you can do much more with it.
2. [script](https://www.kaggle.com/code/citifer/easydel-sfttrainer-example) for supervised fine-tuning with EasyDeL.

## Serving and Generation

### EasyDeL Generation Pipeline: Your Guide to Text Generation with JAX

The `GenerationPipeline` class in EasyDeL provides a streamlined interface for generating text with pre-trained
language models in JAX, with optional token streaming. This introduction covers its purpose, potential applications,
and basic usage.

#### What it Does:

At its core, the `GenerationPipeline` takes your input text (provided as `input_ids` and optionally an `attention_mask`)
and uses a pre-trained language model to predict the next token. This process is repeated iteratively, generating new
text one token at a time, until a stopping condition is met (e.g., reaching a maximum length or encountering a special
end-of-sequence token).
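
The iterative loop described above can be sketched in plain Python. Here `next_token_fn` is a hypothetical stand-in for the model's forward pass plus token selection, and `toy_model` is a made-up rule used only for illustration; the real pipeline runs the model under JAX and yields arrays, not Python ints.

```python
from typing import Callable, Iterator, List


def generate_tokens(
    next_token_fn: Callable[[List[int]], int],
    input_ids: List[int],
    max_new_tokens: int,
    eos_token_id: int,
) -> Iterator[int]:
    """Yield one new token per step until EOS or the length limit is reached."""
    sequence = list(input_ids)
    for _ in range(max_new_tokens):
        token = next_token_fn(sequence)  # predict the next token from the full sequence so far
        yield token                      # streaming: hand each token back as soon as it exists
        if token == eos_token_id:
            break                        # stopping condition: end-of-sequence token
        sequence.append(token)           # feed the new token back in for the next step


def toy_model(seq: List[int]) -> int:
    """Hypothetical 'model': counts upward and emits EOS (0) once the last token reaches 5."""
    return 0 if seq[-1] >= 5 else seq[-1] + 1


print(list(generate_tokens(toy_model, [1, 2], max_new_tokens=10, eos_token_id=0)))  # [3, 4, 5, 0]
```

Making the function a generator is what allows callers to consume tokens one at a time, mirroring the streaming loop in the usage example.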

**Here's how it works:**

1. **Initialization:**
    - You provide a pre-trained `EasyDeLFlaxPretrainedModel`, typically an instance
      of `EasyDeLFlaxPretrainedModelForCausalLM`.
    - You provide the corresponding model parameters (`params`).
    - A `PreTrainedTokenizer` instance handles tokenization, ensuring compatibility between your text and the model.
    - Optionally, you can customize generation behavior using a `GenerationPipelineConfig` object.

2. **Generating Text:**
    - Call the `generate` method with your input text represented as `input_ids` and an optional `attention_mask`.
    - The pipeline iteratively generates new tokens, extending the input sequence.
    - You can either receive each generated token as it's produced or use a `TextIteratorStreamer` to handle streaming
      output.


**Example Usage:**

```python
import easydel as ed
from transformers import AutoTokenizer
from jax import numpy as jnp

# Load your pre-trained model and tokenizer
model, params = ed.AutoEasyDeLModelForCausalLM.from_pretrained(...)
tokenizer = AutoTokenizer.from_pretrained(...)
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"

# Create a GenerationPipeline
pipeline = ed.GenerationPipeline(model=model, params=params, tokenizer=tokenizer)

# Prepare your input
input_text = "The quick brown fox jumps over the "
tokens = tokenizer(input_text, return_tensors="np", max_length=512, padding="max_length")

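# Generate text, printing only the newly decoded characters at each step
outputs = []
printed_length = 0
for token in pipeline.generate(**tokens):
    outputs.append(token)
    # Re-decode the whole generated sequence so spacing and multi-byte characters come out right
    text = tokenizer.decode(jnp.concatenate(outputs, axis=-1)[0])
    print(text[printed_length:], end="", flush=True)
    printed_length = len(text)
```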
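
The example sets `tokenizer.padding_side = "left"` and `tokenizer.truncation_side = "left"`. Decoder-only models append new tokens on the right, so padding has to sit on the left for the last position of every row to be a real token. The helper below is a hypothetical pure-Python illustration of what those tokenizer settings produce, not a `transformers` API:

```python
from typing import List, Tuple


def left_pad(batch: List[List[int]], pad_id: int, length: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-truncate and left-pad token id lists to `length`; return ids and attention mask."""
    input_ids, attention_mask = [], []
    for ids in batch:
        ids = ids[-length:]                       # truncation_side="left": keep the most recent tokens
        pad = length - len(ids)
        input_ids.append([pad_id] * pad + ids)    # padding_side="left": pads go before the prompt
        attention_mask.append([0] * pad + [1] * len(ids))
    return input_ids, attention_mask


ids, mask = left_pad([[5, 6, 7], [8]], pad_id=0, length=4)
print(ids)   # [[0, 5, 6, 7], [0, 0, 0, 8]]
print(mask)  # [[0, 1, 1, 1], [0, 0, 0, 1]]
```

The attention mask marks the padded positions with 0 so the model ignores them.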

**Key Points:**

- **Input Format:** The `generate` method expects `input_ids` (numerical representation of tokens) and optionally
  an `attention_mask` to specify relevant input positions.
- **Output Handling:** You can either iterate over individual generated tokens or employ a `TextIteratorStreamer` for
  streaming output.
- **Customization:** Tailor the generation process with options like `max_new_tokens`, `temperature`, `top_k` sampling,
  and more using the `GenerationPipelineConfig`.
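
As a rough sketch of what `temperature` and `top_k` do to the next-token distribution (the option names come from the text above; the exact semantics inside `GenerationPipelineConfig`, for example treating temperature 0 as greedy decoding, are assumptions here), in plain Python:

```python
import math
import random
from typing import List, Optional


def sample_top_k(logits: List[float], temperature: float = 1.0, top_k: Optional[int] = None,
                 rng: Optional[random.Random] = None) -> int:
    """Sample a token index after temperature scaling and top-k filtering (illustrative)."""
    rng = rng or random.Random()
    if temperature <= 0:  # assumption: temperature 0 means greedy (argmax) decoding
        return max(range(len(logits)), key=logits.__getitem__)
    scaled = [x / temperature for x in logits]  # <1 sharpens, >1 flattens the distribution
    candidates = sorted(range(len(scaled)), key=scaled.__getitem__, reverse=True)
    if top_k is not None:
        candidates = candidates[:top_k]          # keep only the k most likely tokens
    m = max(scaled[i] for i in candidates)       # subtract the max for numerical stability
    weights = [math.exp(scaled[i] - m) for i in candidates]  # unnormalised softmax weights
    return rng.choices(candidates, weights=weights, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0]
print(sample_top_k(logits, temperature=0.7, top_k=1))  # 0: only the argmax survives top_k=1
print(sample_top_k(logits, temperature=1.0, top_k=2, rng=random.Random(0)))  # 0 or 1
```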

The `GenerationPipeline` offers a user-friendly interface to harness the power of EasyDeL's language models for a wide
range of text generation applications.


> [!NOTE]
> For production use, consider `EasyDeLServeEngine`, a serving API engine that is more stable, provides a versioned
> API, and is more efficient.

## EasyDeLState: A Snapshot of Your EasyDeL Model

The `EasyDeLState` class acts like a comprehensive container that holds all the essential information about your EasyDeL
model at a given point in time. Think of it as a snapshot of your model. It includes:

* **Training Progress:**
    * `step`: Tracks the current training step.
* **Model Itself:**
    * `module`:  Holds the actual instance of your EasyDeL model.
    * `module_config`: Stores the model's configuration settings.
    * `module_config_args`:  Keeps track of arguments used to create the configuration (useful for reloading).
    * `apply_fn`:  References the core function that applies your model to data.
* **Learned Parameters:**
    * `params`: Contains the trained weights and biases of your model.
* **Optimizer Information:**
    * `tx`: Stores the optimizer you're using to update the model's parameters (e.g., AdamW).
    * `opt_state`: Keeps track of the optimizer's internal state (this is important for things like momentum in
      optimizers).
    * `tx_init`: Remembers the initial settings used to create the optimizer (again, for reloading purposes).
* **Additional Settings:**
    * `hyperparameters`:  Provides a flexible place to store other hyperparameters related to your model or training
      process.

**Key Capabilities of EasyDeLState:**

* **Initialization (`create`)**: Lets you create a brand new `EasyDeLState` to start training.
* **Loading (`load`, `load_state`, `from_pretrained`)**: Enables you to reload a saved model from a checkpoint file or
  even a pre-trained model from a repository like Hugging Face Hub.
* **Saving (`save_state`)**: Allows you to save your model's current state, including its parameters and optimizer
  state.
* **Optimizer Management (`apply_gradients`, `free_opt_state`, `init_opt_state`)**: Provides methods for updating the
  model's parameters using gradients, releasing optimizer memory, and re-initializing the optimizer if needed.
* **Sharding (`shard_params`)**: Helps you distribute your model's parameters efficiently across multiple devices
  (important for training large models).
* **PyTorch Conversion (`to_pytorch`)**:  Gives you a way to convert your EasyDeL model to its PyTorch equivalent.
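
The immutable-update pattern behind `create`, `apply_gradients`, `free_opt_state`, and `init_opt_state` can be sketched with a hypothetical `ToyState` dataclass. It uses plain SGD with momentum on a dict of floats and is not EasyDeL's implementation, which stores JAX pytrees and an optimizer object `tx`; only the method names are borrowed from the list above.

```python
from dataclasses import dataclass, replace
from typing import Dict, Optional


@dataclass(frozen=True)
class ToyState:
    """Hypothetical miniature train state: step counter, params, and optimizer state."""
    step: int
    params: Dict[str, float]
    opt_state: Optional[Dict[str, float]]  # momentum buffers; None once freed
    learning_rate: float = 0.1
    momentum: float = 0.9

    @classmethod
    def create(cls, params: Dict[str, float]) -> "ToyState":
        return cls(step=0, params=dict(params), opt_state={k: 0.0 for k in params})

    def apply_gradients(self, grads: Dict[str, float]) -> "ToyState":
        if self.opt_state is None:
            raise ValueError("optimizer state was freed; call init_opt_state() first")
        # SGD with momentum: v = mu * v + g ; p = p - lr * v
        velocity = {k: self.momentum * self.opt_state[k] + grads[k] for k in self.params}
        params = {k: self.params[k] - self.learning_rate * velocity[k] for k in self.params}
        # Return a new state instead of mutating: same idea as a functional train state
        return replace(self, step=self.step + 1, params=params, opt_state=velocity)

    def free_opt_state(self) -> "ToyState":
        return replace(self, opt_state=None)  # e.g. to save memory before inference

    def init_opt_state(self) -> "ToyState":
        return replace(self, opt_state={k: 0.0 for k in self.params})


state = ToyState.create({"w": 1.0})
state = state.apply_gradients({"w": 0.5})
print(state.step, state.params)
```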

**In Essence:**

`EasyDeLState` streamlines the process of managing, saving, loading, and even converting your EasyDeL models. It ensures
that you can easily work with your models and maintain consistency throughout your machine learning workflow.

## Supervised Fine-Tuning with EasyDeL

EasyDeL supports both DPO and SFT trainers, which makes working with LLMs in JAX much easier.
Let's look at an example of supervised fine-tuning in JAX with EasyDeL.

```python
from easydel import (
    TrainArguments,
    AutoEasyDeLModelForCausalLM,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    EasyDeLGradientCheckPointers,
    SFTTrainer,
    conversations_formatting_function  # a ready-made prompter for newcomers who don't want to write their own
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

huggingface_repo_id_or_path = "mistralai/Mistral-7B-Instruct-v0.2"

model, params = AutoEasyDeLModelForCausalLM.from_pretrained(huggingface_repo_id_or_path, )

max_length = 4096
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="SFT-EasyDeL",
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDeLOptimizers.ADAMW,
    scheduler=EasyDeLSchedulers.WARM_UP_COSINE,
    weight_decay=0.01,
    total_batch_size=32,
    max_training_steps=None,  # None lets the trainer decide
    do_train=True,
    do_eval=False,  # optional, but supported
    backend="tpu",  # the default backend is cpu, so explicitly choose "tpu", "gpu", or "cpu"
    max_length=max_length,  # note that you also have to change this in the model config
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # how to shard the model across GPUs, CPUs, or TPUs;
    # the -1 axis takes all remaining devices, and data is shared between devices automatically
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)


def prompter(sample):
    return [conversations_formatting_function(tokenizer, messages_field="messages")(sample)]


train_dataset = load_dataset("HuggingFaceH4/deita-10k-v0-sft", split="train_sft")
trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=train_dataset,
    eval_dataset=None,  # we don't have eval dataset rn :)
    tokenizer=tokenizer,
    dataset_text_field=None,
    formatting_func=prompter,
    packing=True,
    num_of_sequences=max_length,
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey! Here's where your model was saved: {output.checkpoint_path}")
```
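
With `packing=True`, short conversations are combined into full-length training rows instead of each being padded to `max_length`. A common way to do this is sketched below (similar in spirit to TRL's `ConstantLengthDataset`); EasyDeL's exact behaviour, including how it uses `num_of_sequences`, is not shown here, and `pack_sequences` is a hypothetical helper:

```python
from typing import Iterable, List


def pack_sequences(tokenized: Iterable[List[int]], eos_id: int, max_length: int) -> List[List[int]]:
    """Concatenate samples (separated by EOS) and cut the stream into fixed-length chunks."""
    buffer: List[int] = []
    chunks: List[List[int]] = []
    for ids in tokenized:
        buffer.extend(ids + [eos_id])        # EOS marks the boundary between samples
        while len(buffer) >= max_length:
            chunks.append(buffer[:max_length])
            buffer = buffer[max_length:]
    return chunks                            # the trailing partial chunk is dropped


print(pack_sequences([[1, 2, 3], [4, 5], [6, 7, 8, 9]], eos_id=0, max_length=4))
# [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]
```

Every packed row is fully used, so no compute is spent on padding, at the cost of samples occasionally being split across rows.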

> [!NOTE]
> You can also use LoRA with the DPO, ORPO, and SFT trainers.

## Fine-Tuning

With EasyDeL, fine-tuning LLMs (causal language models) in JAX and Flax is as easy as possible, and you get the
benefit of TPUs for the best speed. Here's a simple script to fine-tune your own model.

EasyDeL in JAX now feels much closer to the Hugging Face/PyTorch style, so let's fine-tune our model.

```python
from easydel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDeLModelForCausalLM,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    EasyDeLGradientCheckPointers
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

huggingface_repo_id_or_path = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"

model, params = AutoEasyDeLModelForCausalLM.from_pretrained(huggingface_repo_id_or_path, )

max_length = 2048
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="my_first_model_to_train_using_easydel",
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDeLOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDeLSchedulers.LINEAR,
    # "linear","cosine", "none" ,"warm_up_cosine" and "warm_up_linear"  are supported
    weight_decay=0.01,
    total_batch_size=64,
    max_training_steps=None,  # None lets the trainer decide
    do_train=True,
    do_eval=False,  # optional, but supported
    backend="tpu",  # the default backend is cpu, so explicitly choose "tpu", "gpu", or "cpu"
    max_length=max_length,  # note that you also have to change this in the model config
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # how to shard the model across GPUs, CPUs, or TPUs;
    # the -1 axis takes all remaining devices, and data is shared between devices automatically
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)


def ultra_chat_prompting_process(
        data_chunk
):
    user_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "user"
    ]
    assistant_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "assistant"
    ]

    prompt = ""

    for uc, ac in zip(user_part, assistant_part):
        prompt += f"<|user|>\n{uc}</s>\n<|assistant|>\n{ac}</s>\n"

    return {"prompt": prompt}


tokenization_process = lambda data_chunk: tokenizer(
    data_chunk["prompt"],
    add_special_tokens=False,
    max_length=max_length,
    padding="max_length"
)

dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
dataset_train = dataset["train_gen"].map(ultra_chat_prompting_process, num_proc=12)
dataset_train = dataset_train.map(
    tokenization_process,
    num_proc=12,
    remove_columns=dataset_train.column_names
)

# you can do the same for evaluation process dataset

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path=None
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey! Here's where your model was saved: {output.checkpoint_path}")
```
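
The `learning_rate` / `learning_rate_end` pair and the scheduler choice determine how the learning rate moves over training. The functions below are an illustrative guess at the `"linear"` and `"warm_up_cosine"` shapes; whether EasyDeL's warm-up starts from 0 and how it derives the total number of steps are assumptions, and neither function is part of EasyDeL:

```python
import math


def linear_schedule(step: int, total_steps: int, lr_start: float, lr_end: float) -> float:
    """Linearly interpolate from lr_start to lr_end over total_steps."""
    frac = min(step, total_steps) / total_steps
    return lr_start + frac * (lr_end - lr_start)


def warmup_cosine_schedule(step: int, total_steps: int, warmup_steps: int,
                           lr_start: float, lr_end: float) -> float:
    """Linear warm-up from 0 to lr_start, then cosine decay down to lr_end."""
    if step < warmup_steps:
        return lr_start * step / warmup_steps
    decay_steps = total_steps - warmup_steps
    progress = min(step - warmup_steps, decay_steps) / decay_steps  # 0 -> 1 over the decay phase
    return lr_end + 0.5 * (lr_start - lr_end) * (1 + math.cos(math.pi * progress))


# Values matching the example above: learning_rate=5e-5, learning_rate_end=1e-6
for step in (0, 500, 1000):
    print(step, linear_schedule(step, 1000, 5e-5, 1e-6),
          warmup_cosine_schedule(step, 1000, 100, 5e-5, 1e-6))
```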

> [!TIP]
> You can then convert the model to PyTorch. I don't recommend JAX/Flax for hosting models, since
> PyTorch is the better option on GPUs.

## DPO Fine-tuning

`DPOTrainer` is a new trainer in EasyDeL, so you may run into some bugs, but as far as I have tested, everything works
fine. You can consider it the first DPO trainer in JAX/Flax. Let's walk through an example of fine-tuning your own
model with `DPOTrainer`.

> [!TIP]
> If you want a more complete script for learning about `DPOTrainer`, see the example
> [here](https://github.com/erfanzar/EasyDeL/blob/main/examples/training/dpo/dpo_training_example.py), which
> DPO-tunes a Mixtral model on the Intel DPO dataset.

```python
import easydel
from easydel import (
    TrainArguments,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    EasyDeLGradientCheckPointers,
    DPOTrainer,
    EasyDeLState,
    easystate_to_huggingface_model
)

from datasets import load_dataset
from huggingface_hub import HfApi
from transformers import AutoTokenizer, LlamaForCausalLM as module_pt
from jax import numpy as jnp
import jax
from jax.sharding import PartitionSpec
from fjformer import GenerateRNG
from typing import Optional, Dict
from datasets import Dataset

rng_g = GenerateRNG()
api = HfApi()

max_length = 512  # Overall maximum length (prompt + target must fit within it)
max_target_length = 256  # Maximum length for the target column in the dataset
max_prompt_length = 256  # Maximum length for the prompt column in the dataset

model_name_or_path = "erfanzar/LinguaMatic-Tiny"
ref_model_name_or_path = "teknium/OpenHermes-2.5-Mistral-7B"
dtype = jnp.bfloat16

sharding_axis_dims = (1, -1, 1, 1)
sharding_axis_names = ("dp", "fsdp", "tp", "sp")


def extract_anthropic_prompt(prompt_and_response):
    """
    Extract the anthropic prompt from a prompt and response pair.
    """
    search_term = "\n\nAssistant:"
    search_term_idx = prompt_and_response.rfind(search_term)
    assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
    return prompt_and_response[: search_term_idx + len(search_term)]


def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: Optional[str] = None) -> Dataset:
    """
    Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts should be structured as follows:
      \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
    """
    dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def split_prompt_and_responses(sample) -> Dict[str, str]:
        prompt = extract_anthropic_prompt(sample["chosen"])
        return {
            "prompt": prompt,
            "chosen": sample["chosen"][len(prompt):],
            "rejected": sample["rejected"][len(prompt):],
        }

    return dataset.map(split_prompt_and_responses)


arguments = TrainArguments(
    model_name="EasyDeL-DPO",
    num_train_epochs=5,
    learning_rate=1e-4,
    learning_rate_end=3e-5,
    warmup_steps=200,
    optimizer=EasyDeLOptimizers.ADAMW,
    scheduler=EasyDeLSchedulers.LINEAR,
    weight_decay=0.02,
    total_batch_size=128,
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=sharding_axis_dims,
    fully_sharded_data_parallel=True,
    gradient_accumulation_steps=2,
    dtype=dtype,
    param_dtype=dtype,
    step_start_point=0,
    training_time="7H",
    do_train=True,
    do_eval=True,
    track_memory=False  # Performance boost.
    # You can set other options too or play with them but for now I just stick with these arguments.
)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

train_dataset = get_hh("train", sanity_check=True)
eval_dataset = get_hh("test", sanity_check=True)

state = EasyDeLState.from_pretrained(
    pretrained_model_name_or_path=model_name_or_path,
    dtype=dtype,
    param_dtype=dtype,
    init_optimizer_state=False,
    free_optimizer_state=True,
    sharding_axis_dims=sharding_axis_dims,
    sharding_axis_names=sharding_axis_names,
    partition_axis=easydel.PartitionAxis(
        batch_axis=("dp", "fsdp"),
        query_sequence_axis="sp",
        key_sequence_axis="sp",
        head_axis="tp",
        attention_dim_axis=None
    )
)

ref_state = EasyDeLState.from_pretrained(
    pretrained_model_name_or_path=ref_model_name_or_path,
    dtype=dtype,
    param_dtype=dtype,
    init_optimizer_state=False,
    free_optimizer_state=True,
    sharding_axis_dims=sharding_axis_dims,
    sharding_axis_names=sharding_axis_names,
    partition_axis=easydel.PartitionAxis(
        batch_axis=("dp", "fsdp"),
        query_sequence_axis="sp",
        key_sequence_axis="sp",
        head_axis="tp",
        attention_dim_axis=None
    )
)

dpo_trainer = DPOTrainer(
    model_state=state,
    ref_model_state=ref_state,
    beta=0.1,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    arguments=arguments,
    max_length=max_length,
    max_target_length=max_target_length,
    max_prompt_length=max_prompt_length,
    ref_model_init_kwargs=None,  # required if you pass ref_model_state as a string
    model_init_kwargs=None,  # required if you pass model_state as a string
    dataset_map_arguments={
        "num_proc": 8,
        "batched": True,
        "batch_size": 100,
    },
    auto_shard_model_state=True,
    auto_shard_ref_model_state=True,
    loss_type="sigmoid",
    data_collator=None,  # pass None to use the default data collator (or create your own)
)

output = dpo_trainer.train()

easydel_jax_model = output.state  # Here's your EasyDeL model

with jax.default_device(jax.devices("cpu")[0]):
    model = easystate_to_huggingface_model(
        state=EasyDeLState.load_state(
            output.checkpoint_path
        ),
        base_huggingface_module=module_pt,
        config=dpo_trainer.model_state.module.config
    )  # Here's your PyTorch model

model.push_to_hub("<REPO_ID>", private=False)  # Hope you love open-source too :)
tokenizer.push_to_hub("<REPO_ID>", private=False)  # Hope you love open-source too :)
```

Now you have trained your first model using `DPOTrainer` in JAX with EasyDeL.

> [!TIP]
> The API of EasyDeL's DPO trainer is similar to the DPO trainer in Hugging Face's TRL, and the code is
> hackable and easy to modify.
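
`loss_type="sigmoid"` with `beta=0.1` refers to the standard DPO objective: for a prompt x with chosen answer y_w and rejected answer y_l, the loss is -log σ(β · [(log π(y_w|x) - log π_ref(y_w|x)) - (log π(y_l|x) - log π_ref(y_l|x))]). The scalar sketch below computes it for one preference pair from sequence log-probabilities; `dpo_sigmoid_loss` is a hypothetical helper, EasyDeL presumably computes the same quantity batched in JAX, and its other loss types are not covered:

```python
import math


def _softplus(x: float) -> float:
    """log(1 + exp(x)) computed without overflow."""
    if x > 0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))


def dpo_sigmoid_loss(policy_chosen_logp: float, policy_rejected_logp: float,
                     ref_chosen_logp: float, ref_rejected_logp: float, beta: float = 0.1) -> float:
    """Sigmoid DPO loss for a single preference pair (sequence-level log-probabilities)."""
    chosen_logratio = policy_chosen_logp - ref_chosen_logp        # how much the policy upweights y_w
    rejected_logratio = policy_rejected_logp - ref_rejected_logp  # how much it upweights y_l
    logits = beta * (chosen_logratio - rejected_logratio)
    return _softplus(-logits)  # == -log(sigmoid(logits))


# Policy identical to the reference: loss is log(2)
print(dpo_sigmoid_loss(-5.0, -7.0, -5.0, -7.0))
# Policy prefers the chosen answer more than the reference does: loss drops below log(2)
print(dpo_sigmoid_loss(-10.0, -12.0, -11.0, -11.0))
```

A larger `beta` makes the loss react more sharply to the policy drifting away from the reference model.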

## EasyDeLState

`EasyDeLState` is a new feature in EasyDeL with many options, such as storing _model parameters, optimizer state,
model config, model type, and optimizer and scheduler configs_.

Let's look at some examples of using `EasyDeLState`.

### Fine-tuning

Fine-tuning from a previous state or from a new state:

```python
from easydel import (
    AutoEasyDeLConfig,
    EasyDeLState,
    PartitionAxis
)
from transformers import AutoTokenizer
from jax import numpy as jnp, lax
import jax

huggingface_model_repo_id = "REPO_ID"
checkpoint_name = "CKPT_NAME"

state = EasyDeLState.from_pretrained(
    pretrained_model_name_or_path=huggingface_model_repo_id,
    filename=checkpoint_name,
    optimizer="adamw",
    scheduler="none",
    tx_init=None,
    device=jax.devices('cpu')[0],  # Offload Device
    dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
    precision=lax.Precision("fastest"),
    sharding_axis_dims=(1, -1, 1, 1),
    # how to shard the model across GPUs, CPUs, or TPUs; with (1, -1, 1, 1)
    # the -1 axis takes all remaining devices, and data is shared between devices automatically
    sharding_axis_names=("dp", "fsdp", "tp", "sp"),
    partition_axis=PartitionAxis(
        batch_axis=("dp", "fsdp"),
        query_sequence_axis="sp",
        key_sequence_axis="sp",
        head_axis="tp",
        attention_dim_axis=None
    ),
    shard_attention_computation=True,
    input_shape=(1, 1),
    backend=None,
    init_optimizer_state=False,
    free_optimizer_state=True,
    verbose=True,
    state_shard_fns=None,
)

config = AutoEasyDeLConfig.from_pretrained(
    huggingface_model_repo_id
)

tokenizer = AutoTokenizer.from_pretrained(
    huggingface_model_repo_id,
    trust_remote_code=True
)

max_length = config.max_position_embeddings

configs_to_initialize_model_class = {
    'config': config,
    'dtype': jnp.bfloat16,
    'param_dtype': jnp.bfloat16,
    'input_shape': (8, 8)
}
```

`EasyDeLState` also has `.load_state()` and `.save_state()`, along with other useful methods such as
`.free_opt_state()`, which frees the optimizer state, and `.shard_params()`, which shards the parameters. See the
documentation to learn more about these options.
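
The `sharding_axis_dims=(1, -1, 1, 1)` and `sharding_axis_names=("dp", "fsdp", "tp", "sp")` pair used in these examples describes a four-axis device mesh. To our understanding, the `-1` is resolved like a NumPy reshape, to however many devices remain, so on 8 devices the whole mesh goes to the `fsdp` axis. The hypothetical helper below only shows that arithmetic; the real code builds a `jax.sharding.Mesh`:

```python
import math
from typing import Dict, Sequence


def resolve_mesh_shape(axis_dims: Sequence[int], axis_names: Sequence[str],
                       num_devices: int) -> Dict[str, int]:
    """Replace a single -1 in axis_dims with the leftover device count, like a reshape."""
    if list(axis_dims).count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = math.prod(d for d in axis_dims if d != -1)  # devices used by the fixed axes
    if num_devices % known != 0:
        raise ValueError(f"{num_devices} devices cannot be split by fixed axes of size {known}")
    dims = [num_devices // known if d == -1 else d for d in axis_dims]
    if math.prod(dims) != num_devices:
        raise ValueError(f"mesh {dims} does not use exactly {num_devices} devices")
    return dict(zip(axis_names, dims))


print(resolve_mesh_shape((1, -1, 1, 1), ("dp", "fsdp", "tp", "sp"), num_devices=8))
# {'dp': 1, 'fsdp': 8, 'tp': 1, 'sp': 1}
print(resolve_mesh_shape((2, -1, 2, 1), ("dp", "fsdp", "tp", "sp"), num_devices=8))
# {'dp': 2, 'fsdp': 2, 'tp': 2, 'sp': 1}
```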

### Converting to Hugging Face and PyTorch

Let's see how to convert an EasyDeL Mistral model from a trained state into a Hugging Face PyTorch Mistral model:

```python
from transformers import MistralForCausalLM
from easydel import (
    AutoEasyDeLConfig,
    EasyDeLState,
    easystate_to_huggingface_model
)
import jax

huggingface_model_repo_id = "REPO_ID"

config = AutoEasyDeLConfig.from_pretrained(
    huggingface_model_repo_id
)
with jax.default_device(jax.devices("cpu")[0]):
    model = easystate_to_huggingface_model(
        state=EasyDeLState.load_state(
            "PATH_TO_CKPT"
        ),  # You can Pass EasyDeLState here
        base_huggingface_module=MistralForCausalLM,  # type: ignore
        config=config
    )

model = model.half()  # it's a huggingface model now
```

### Other Use Cases

`EasyDeLState` is general-purpose: you can use it throughout EasyDeL, for example for a stand-alone model, serving,
fine-tuning, and many other features. It's up to you to see how creative you can be 😇.

## AttentionModule: A Versatile Attention Mechanism Factory

The `AttentionModule` class is designed to simplify the creation and execution of different attention mechanisms within
your EasyDeL models. It provides a unified interface for working with various attention types, allowing you to easily
switch between them and experiment with different configurations.

**Key Features:**

* **Mechanism Selection:** The `attn_mechanism` argument lets you choose the specific attention algorithm you want to
  use (e.g., `"vanilla"`, `"flash"`, `"splash"`, `"ring"`, `"cudnn"`).
* **Sharding and Partitioning:** The class supports advanced JAX sharding techniques to distribute attention
  computations across multiple devices for efficient processing of large models. It handles partitioning of query, key,
  value, bias, and attention weight matrices using `PartitionSpec`.
* **Blockwise Attention:** Enables the use of blockwise attention for increased memory efficiency, especially with long
  sequences.
* **Caching Support:** Facilitates the use of attention caching to speed up inference and generation tasks.
* **Dropout and Determinism:** Allows for applying dropout to attention weights and controlling the deterministic
  behavior of the attention computation.
* **Testing Utility:**  Provides a `test_attentions` method to compare different attention mechanisms in terms of
  accuracy, gradient stability, and computation time.
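
Blockwise attention saves memory by never materialising a full row of attention scores. The key trick, the "online softmax" used by FlashAttention-style kernels, can be shown for a single query in plain Python. This is an illustrative sketch, not EasyDeL's kernel; `full_attention` and `blockwise_attention` are hypothetical names:

```python
import math
from typing import List


def _scores(q: List[float], keys: List[List[float]]) -> List[float]:
    """Scaled dot-product scores q.k / sqrt(d) for one query."""
    d = len(q)
    return [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]


def full_attention(q: List[float], keys: List[List[float]], values: List[List[float]]) -> List[float]:
    """Reference: softmax over all scores, then a weighted sum of the values."""
    scores = _scores(q, keys)
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / z for j in range(len(values[0]))]


def blockwise_attention(q: List[float], keys: List[List[float]], values: List[List[float]],
                        block_size: int) -> List[float]:
    """Same result, visiting keys/values one block at a time with a running max and denominator."""
    running_max = -math.inf
    denom = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        scores = _scores(q, keys[start:start + block_size])
        v_blk = values[start:start + block_size]
        new_max = max(running_max, max(scores))
        correction = math.exp(running_max - new_max)  # rescale everything accumulated so far
        denom = denom * correction + sum(math.exp(s - new_max) for s in scores)
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            p = math.exp(s - new_max)
            acc = [a + p * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / denom for a in acc]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
ref = full_attention(q, keys, values)
blk = blockwise_attention(q, keys, values, block_size=2)
print(max(abs(a - b) for a, b in zip(ref, blk)))  # only floating-point noise
```

Only one block of scores exists at a time, which is why memory stays flat as sequences grow.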

**How it Works:**

1. **Initialization:**
    - During initialization, you provide the desired `attn_mechanism`, JAX `mesh` for sharding, scaling
      factor (`sm_scale`), number of attention heads, head dimensions, and other configuration parameters.
    - The class automatically sets default values for many parameters based on the chosen attention mechanism and the
      provided EasyDeL configuration (`base_module_class`).
2. **Calling the Module:**
    - When you call the `AttentionModule` object, you pass in the query, key, and value states, along with optional
      parameters like attention masks, biases, and causal flags.
    - The module internally selects the appropriate attention function based on the specified `attn_mechanism`.
    - It performs any necessary sharding and partitioning based on the configured partition specifications.
    - The attention computation is executed, and the attention outputs (and optionally attention weights) are returned.
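
The selection step described above amounts to a name-to-function lookup. The hypothetical sketch below registers only a pure-Python `"vanilla"` mechanism with an optional causal mask; `ATTENTION_MECHANISMS` and `run_attention` are illustrative names, not EasyDeL's API, and sharding is omitted entirely:

```python
import math
from typing import Callable, Dict, List

Matrix = List[List[float]]


def vanilla_attention(q: Matrix, k: Matrix, v: Matrix, causal: bool = False) -> Matrix:
    """Single-head scaled dot-product attention; with causal=True, query i sees keys j <= i only."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        scores = [sum(a * b for a, b in zip(qi, kj)) * scale for kj in k]
        if causal:
            scores = [s if j <= i else -math.inf for j, s in enumerate(scores)]  # mask the future
        m = max(scores)
        w = [math.exp(s - m) for s in scores]  # masked positions get weight exp(-inf) == 0
        z = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / z for c in range(len(v[0]))])
    return out


# Hypothetical registry: mechanism name -> implementation
ATTENTION_MECHANISMS: Dict[str, Callable[..., Matrix]] = {"vanilla": vanilla_attention}


def run_attention(attn_mechanism: str, q: Matrix, k: Matrix, v: Matrix, causal: bool = False) -> Matrix:
    """Pick the implementation by name and run it."""
    try:
        fn = ATTENTION_MECHANISMS[attn_mechanism]
    except KeyError:
        raise ValueError(f"unknown attn_mechanism {attn_mechanism!r}; "
                         f"available: {sorted(ATTENTION_MECHANISMS)}") from None
    return fn(q, k, v, causal=causal)


q = k = [[1.0, 0.0], [0.0, 1.0]]
v = [[10.0, 0.0], [0.0, 10.0]]
print(run_attention("vanilla", q, k, v, causal=True)[0])  # [10.0, 0.0]: row 0 sees only key 0
```

Adding another mechanism is then a matter of registering one more function with the same signature.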

**Advantages:**

* **Flexibility:**  Allows you to easily switch between different attention mechanisms without major code changes.
* **Efficiency:**  Supports advanced JAX sharding for distributed computation, enabling the handling of large models.

_FlashAttention works well on TPU; GPU support is still being improved._

## EasyDeLXRapTure for layer tuning and LoRA

When applying LoRA to EasyDeL models, there are a few things you may need to configure yourself, but EasyDeL handles
most of the work. Let's jump into a LoRA fine-tuning example that uses _EasyDeLXRapTure_ with a Mistral model and
FlashAttention.
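
Conceptually, LoRA keeps each targeted weight W frozen and trains a low-rank pair B (d_out × r) and A (r × d_in), so the effective weight is W + s·BA. In the XRapTure configuration, `lora_dim` presumably sets the rank r and `lora_fine_tune_parameters` names the layers that get adapters; merging, which `merge_lora_rapture_parameters=True` performs at the end of training, folds BA into W so inference needs no extra matmul. The scale s and the zero initialisation of B follow the standard LoRA formulation and are assumptions about XRapTure; the helpers are hypothetical:

```python
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, y: Matrix) -> Matrix:
    """Plain matrix product x @ y."""
    return [[sum(x[i][t] * y[t][j] for t in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]


def lora_merge(w: Matrix, a: Matrix, b: Matrix, scale: float = 1.0) -> Matrix:
    """Return W + scale * (B @ A), where W is d_out x d_in, B is d_out x r, and A is r x d_in."""
    delta = matmul(b, a)  # low-rank update with the same shape as W
    return [[w[i][j] + scale * delta[i][j] for j in range(len(w[0]))] for i in range(len(w))]


w = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # frozen 2x3 weight
a = [[0.1, 0.2, 0.3]]                    # rank r = 1: A is 1x3
b_init = [[0.0], [0.0]]                  # B starts at zero, so the merged weight equals W
b_trained = [[1.0], [2.0]]               # pretend these values were learned

print(lora_merge(w, a, b_init) == w)     # True
print(lora_merge(w, a, b_trained))
```

Only A and B (here 3 + 2 numbers instead of 6) are trained, which is where LoRA's memory savings come from.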

```python
from flax.core import FrozenDict
from easydel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDeLModelForCausalLM,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    EasyDeLGradientCheckPointers,
    EasyDeLXRapTureConfig
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

huggingface_repo_id_or_path = "mistralai/Mistral-7B-Instruct-v0.1"

model, params = AutoEasyDeLModelForCausalLM.from_pretrained(huggingface_repo_id_or_path, )

max_length = 8192  # a multiple of the 1024-token attention blocks configured below
model_parameters = FrozenDict({"params": params})

dtype = jnp.bfloat16
param_dtype = jnp.bfloat16  # you can change that if you want 

tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)

model.config.add_basic_configurations(
    attn_mechanism="flash",  # Using FlashAttention
    block_b=1,
    block_q=1024,
    block_k=1024,
    block_k_major=1024,
)

tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": dtype,
    "param_dtype": param_dtype,
    "input_shape": (1, 1)
}

rapture = EasyDeLXRapTureConfig(
    parameters=model_parameters,
    lora_dim=64,
    fully_fine_tune_parameters=["embed_tokens"],  # Model layer to be fully fine tuned
    lora_fine_tune_parameters=["q_proj", "v_proj", "k_proj", "o_proj"],  # LoRA target layers; set this to None
    # for layer tuning / transfer learning only
    verbose=True
)

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="EasyDeL-Lora-Example",
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=1e-4,  # a higher learning rate is recommended for LoRA
    learning_rate_end=8e-5,
    optimizer=EasyDeLOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDeLSchedulers.LINEAR,
    # "linear","cosine", "none" ,"warm_up_cosine" and "warm_up_linear"  are supported
    weight_decay=0.01,
    total_batch_size=512,
    max_training_steps=None,  # None to let trainer Decide
    do_train=True,
    do_eval=False,  # it's optional but supported 
    backend="tpu",  # the default backend is cpu, so set "tpu", "gpu", or "cpu" explicitly
    max_length=max_length,  # Note that you have to change this in the model config too
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    # how to shard the model across GPUs, CPUs, or TPUs; with (1, -1, 1, 1), training is
    # sequence- and model-parallel automatically and data is shared between devices
    sharding_array=(1, -1, 1, 1),
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=1,
    loss_re_mat="",
    dtype=dtype,
    param_dtype=param_dtype,
    rapture_config=rapture,
    # Merges the LoRA parameters into the original model parameters at the end of training.
    # Turning this off is not supported yet and is not recommended.
    merge_lora_rapture_parameters=True
)


def ultra_chat_prompting_process(data_chunk):
    # Turn one UltraChat record into a single chat-formatted prompt string.
    user_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "user"
    ]
    assistant_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "assistant"
    ]

    prompt = ""

    for uc, ac in zip(user_part, assistant_part):
        prompt += f"<|user|>\n{uc}</s>\n<|assistant|>\n{ac}</s>\n"

    return {"prompt": prompt}


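def tokenization_process(data_chunk):
    # Tokenize the formatted prompt, truncating and padding every example to max_length.
    return tokenizer(
        data_chunk["prompt"],
        add_special_tokens=False,
        max_length=max_length,
        padding="max_length",
        truncation=True
    )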

dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
dataset_train = dataset["train_gen"].map(ultra_chat_prompting_process, num_proc=12)
dataset_train = dataset_train.map(
    tokenization_process,
    num_proc=12,
    remove_columns=dataset_train.column_names
)

# you can prepare an evaluation dataset the same way

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path=None
)

# When using LoRA or transfer learning, do not pass the parameters to trainer.train();
# they are already provided through rapture_config.
output = trainer.train()
print(f"Your model was saved at {output.checkpoint_path}")
```
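The `lora_dim=64` setting above is the rank of the LoRA adapters, and `merge_lora_rapture_parameters=True` folds them back into the original weights when training ends. The snippet below is a minimal JAX sketch of the standard LoRA formulation, meant only to show what those two options amount to; it is not EasyDeL's XRapTure implementation. The helper names (`init_lora`, `lora_dense`, `merge_lora`) are hypothetical, and the `alpha / rank` scaling and zero-initialized `B` follow the common LoRA convention, which XRapTure may or may not match.

```python
import jax
import jax.numpy as jnp


def init_lora(key, in_dim, out_dim, rank):
    """Create LoRA factors for a frozen (in_dim, out_dim) kernel (hypothetical helper).

    A is small random noise and B starts at zero, so the adapted layer is
    initially identical to the frozen one (the usual LoRA initialization).
    """
    a = 0.01 * jax.random.normal(key, (in_dim, rank))
    b = jnp.zeros((rank, out_dim))
    return a, b


def lora_dense(x, kernel, a, b, alpha, rank):
    """Frozen projection plus a scaled low-rank update: x @ W + (alpha / r) * (x @ A) @ B."""
    return x @ kernel + (alpha / rank) * ((x @ a) @ b)


def merge_lora(kernel, a, b, alpha, rank):
    """Fold the adapter into the base kernel so inference needs a single matmul."""
    return kernel + (alpha / rank) * (a @ b)


key_w, key_a, key_x, key_b = jax.random.split(jax.random.PRNGKey(0), 4)
in_dim, out_dim, rank, alpha = 16, 8, 4, 8.0

kernel = jax.random.normal(key_w, (in_dim, out_dim))  # stands in for a frozen projection kernel
a, b = init_lora(key_a, in_dim, out_dim, rank)
x = jax.random.normal(key_x, (2, in_dim))

# Before any training step the adapter contributes nothing.
y_init = lora_dense(x, kernel, a, b, alpha, rank)

# Pretend training moved B away from zero, then merge the adapter into the kernel.
b = jax.random.normal(key_b, (rank, out_dim))
y_adapter = lora_dense(x, kernel, a, b, alpha, rank)
y_merged = x @ merge_lora(kernel, a, b, alpha, rank)
print("max |adapter - merged|:", float(jnp.max(jnp.abs(y_adapter - y_merged))))
```

Because only `A` and `B` are trained, a 4096×4096 projection adapted at rank 64 has 64 × (4096 + 4096) = 524,288 trainable parameters instead of 16,777,216, i.e. 1/32 of the full matrix. Merging afterwards means the final model carries no extra adapter matmul at inference time.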

## Contributing

EasyDeL is an open-source project, and contributions are welcome. If you would like to contribute to EasyDeL, please
fork the repository, make your changes, and submit a pull request. The team behind EasyDeL will review your changes and
merge them if they are suitable.

## License 📜

EasyDeL is fully open source and released under the Apache v2 license. Please see the LICENSE file in the root directory
of this project for more information.

## Contact

If you have any questions or comments about EasyDeL, you can reach out to me at Erfanzare810@gmail.com.

## Citing EasyDeL 🥶

To cite this repository:

```bibtex
@misc{ZareChavoshi_2023,
    title={EasyDeL, an open-source library, is specifically designed to enhance and streamline the training process of machine learning models. It focuses primarily on Jax/Flax and aims to provide convenient and effective solutions for training Flax/Jax Models on TPU/GPU for both Serving and Training purposes.},
    url={https://github.com/erfanzar/EasyDeL},
    journal={EasyDeL Easy and Fast DeepLearning with JAX},
    publisher={Erfan Zare Chavoshi},
    author={Zare Chavoshi, Erfan},
    year={2023}
}
```

            

Raw data

            {
    "_id": null,
    "home_page": null,
    "name": "EasyDeL",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.8",
    "maintainer_email": null,
    "keywords": "JAX, Torch, Deep Learning, Machine Learning, Flax, XLA",
    "author": null,
    "author_email": "Erfan Zare Chavoshi <Erfanzare810@gmail.com>",
    "download_url": "https://files.pythonhosted.org/packages/92/c5/ffad63ed91ad7efc1ff6315e790b6e167d7b62f00ca5fe78c438378104ed/easydel-0.0.67.tar.gz",
    "platform": null,
    "description": "# EasyDeL \ud83d\udd2e\n\nEasyDeL is an open-source framework designed to enhance and streamline the training process of machine learning models.\nWith a primary focus on Jax/Flax, EasyDeL aims to provide convenient and effective solutions for training Flax/Jax\nmodels on TPU/GPU for both serving and training purposes.\n\n## Key Features\n\n1. **Trainers**: EasyDeL offers a range of trainers, including DPOTrainer, ORPOTrainer, SFTTrainer, and VideoCLM\n   Trainer, tailored for specific training requirements.\n\n2. **Serving and API Engines**: EasyDeL provides serving and API engines for efficiently using and serving large\n   language models (LLMs) in JAX, enabling seamless integration into various applications.\n\n3. **Quantization Support**: EasyDeL supports quantization methods for all models, allowing for efficient inference and\n   training.\n\n4. **Bit Operation Support**: EasyDeL supports 8, 6, and 4-bit operations for inference and training in JAX, optimizing\n   performance and resource utilization.\n\n5. **Diverse Model Support**: EasyDeL offers a wide range of models in JAX that have never been implemented before, such\n   as Falcon, Qwen2, Phi2, Mixtral, Qwen2Moe, Cohere, Dbrx, Phi3, and MPT.\n\n6. **FlashAttention Integration**: EasyDeL integrates FlashAttention in JAX for GPUs and TPUs, enhancing performance and\n   efficiency.\n\n7. **Automatic LLM Serving**: EasyDeL enables automatic serving of LLMs with mid and high-level APIs in both JAX and\n   PyTorch, simplifying deployment and integration.\n\n8. **LLM Training and Fine-tuning**: EasyDeL provides LLM trainer and fine-tuner capabilities in JAX, allowing for\n   efficient training and customization of language models.\n\n9. **Video CLM Training and Fine-tuning**: EasyDeL supports Video CLM trainer and fine-tuner for models such as Falcon,\n   Qwen2, Phi2, MPT, Mixtral, Grok-1, and Qwen2Moe, enabling advanced video-related applications.\n\n10. 
**Performance Optimization**: EasyDeL provides various features to enhance the training process and optimize\n    performance, such as LoRA (Low-Rank Adaptation of Large Language Models), RingAttention, FlashAttention, BlockWise\n    FFN, and Efficient Attention support (through the FJFormer backbone).\n\n11. **Model Conversion**: EasyDeL supports automatic conversion of models from JAX-EasyDeL to PyTorch-HF and vice versa,\n    facilitating seamless integration with different frameworks.\n\nWith its comprehensive set of features and tools, EasyDeL aims to streamline and accelerate the training and deployment\nof machine learning models, particularly in the domain of large language models and video-related applications.\n\n### Latest News \ud83d\udd25\n\n- removing *(q,k,v,b,a)_partition_specs and using `PartitionAxis` instead of them.\n- Sharding Strategies are changed.\n- Now EasyDeL is more Memory efficient Multi-GPUs\n- Aya Model is Supported.\n- Falcon Model is Updated and now Falcon11B is supported wih flash attention support.\n- `pallas_flash` is now available for CPU/GPU/TPU with custom pallas kernel.\n- DeepseekV2 Model is Added (beta mood).\n- OpenELM Model is Added.\n- EasyDeL project structure has changed now you have to import EasyDel as `easydel`.\n- `ORPOTrainer` is Added\n- Phi3 Model bugs are fixed, Arctic Model is added.\n\n> [!TIP]\n>\n> use `ed.AttentionModule.test_attentions()` to find the best attention mechanism\n> that works for you\n> ```python\n> import easydel as ed\n> ed.AttentionModule.test_attentions()\n> ```\n\n## Documentation \ud83d\udcab\n\n> [!IMPORTANT]\n> Documents and Examples are ready at [Here](https://easydel.readthedocs.io/en/latest/)\n> Please have that in mind that EasyDeL is in the loop of fast-development\n> so we might have API changes.\n\n### Hands on Code Kaggle Examples\n\n1. 
[script](https://www.kaggle.com/citifer/easydel-causal-language-model-trainer-example) for mindset of using EasyDeL\n   CausalLanguageModelTrainer on kaggle, but you can do much more.\n2. [script](https://www.kaggle.com/code/citifer/easydel-sfttrainer-example) SuperVised Finetuning with EasyDeL.\n\n## Serving and Generation\n\n### EasyDeL Generation Pipeline: Your Guide to Text Generation with JAX\n\nThe `GenerationPipeline` class in EasyDeL provides a streamlined interface for generating text using pre-trained\nlanguage\nmodels within the JAX framework that support token streaming option. This introduction will guide you through its\npurpose, potential applications, and basic usage.\n\n#### What it Does:\n\nAt its core, the GenerationPipeline takes your input text (provided as `input_ids` and optionally an `attention_mask`)\nand\nuses a pre-trained language model to predict the most likely following tokens. This process is repeated iteratively,\ngenerating new text one token at a time, until a stopping condition is met (e.g., reaching a maximum length or\nencountering a special end-of-sequence token).\n\n**Here's how it works:**\n\n1. **Initialization:**\n    - You provide a pre-trained `EasyDeLFlaxPretrainedModel`, typically an instance\n      of `EasyDeLFlaxPretrainedModelForCausalLM`.\n    - You provide the corresponding model parameters (`params`).\n    - A `PreTrainedTokenizer` instance handles tokenization, ensuring compatibility between your text and the model.\n    - Optionally, you can customize generation behavior using a `GenerationPipelineConfig` object.\n\n2. 
**Generating Text:**\n    - Call the `generate` method with your input text represented as `input_ids` and an optional `attention_mask`.\n    - The pipeline iteratively generates new tokens, extending the input sequence.\n    - You can either receive each generated token as it's produced or use a `TextIteratorStreamer` to handle streaming\n      output.\n\n\n**Example Usage:**\n\n```python\nimport easydel as ed\nfrom transformers import AutoTokenizer\nfrom jax import numpy as jnp\n\n# Load your pre-trained model and tokenizer\nmodel, params = ed.AutoEasyDeLModelForCausalLM.from_pretrained(...)\ntokenizer = AutoTokenizer.from_pretrained(...)\ntokenizer.padding_side = \"left\"\ntokenizer.truncation_side = \"left\"\n\n# Create a GenerationPipeline\npipeline = ed.GenerationPipeline(model=model, params=params, tokenizer=tokenizer)\n\n# Prepare your input\ninput_text = \"The quick brown fox jumps over the \"\ntokens = tokenizer(input_text, return_tensors=\"np\", max_length=512, padding=\"max_length\")\n\n# Generate text\noutputs = []\npl = 0\nfor token in pipeline.generate(**tokens):\n    outputs.append(token)\n    sq = tokenizer.decode(jnp.concatenate(outputs, axis=-1)[0])\n    print(sq[pl:],end=\"\")\n    pl = len(sq)\n```\n\n**Key Points:**\n\n- **Input Format:** The `generate` method expects `input_ids` (numerical representation of tokens) and optionally\n  an `attention_mask` to specify relevant input positions.\n- **Output Handling:** You can either iterate over individual generated tokens or employ a `TextIteratorStreamer` for\n  streaming output.\n- **Customization:** Tailor the generation process with options like `max_new_tokens`, `temperature`, `top_k` sampling,\n  and more using the `GenerationPipelineConfig`.\n\nThe `GenerationPipeline` offers a user-friendly interface to harness the power of EasyDeL's language models for a wide\nrange of text generation applications.\n\n\n> [!NOTE]\n> you can use `EasyDeLServeEngine` which is a Serve API Engine for 
production purpose sice that's more stable provide\n> versioned\n> API and efficient.\n\n## EasyDeLState A Snapshot of Your EasyDeL Model\n\nThe `EasyDeLState` class acts like a comprehensive container that holds all the essential information about your EasyDeL\nmodel at a given point in time. Think of it as a snapshot of your model. It includes:\n\n* **Training Progress:**\n    * `step`: Tracks the current training step.\n* **Model Itself:**\n    * `module`:  Holds the actual instance of your EasyDeL model.\n    * `module_config`: Stores the model's configuration settings.\n    * `module_config_args`:  Keeps track of arguments used to create the configuration (useful for reloading).\n    * `apply_fn`:  References the core function that applies your model to data.\n* **Learned Parameters:**\n    * `params`: Contains the trained weights and biases of your model.\n* **Optimizer Information:**\n    * `tx`: Stores the optimizer you're using to update the model's parameters (e.g., AdamW).\n    * `opt_state`: Keeps track of the optimizer's internal state (this is important for things like momentum in\n      optimizers).\n    * `tx_init`: Remembers the initial settings used to create the optimizer (again, for reloading purposes).\n* **Additional Settings:**\n    * `hyperparameters`:  Provides a flexible place to store other hyperparameters related to your model or training\n      process.\n\n**Key Capabilities of EasyDeLState:**\n\n* **Initialization (`create`)**: Lets you create a brand new `EasyDeLState` to start training.\n* **Loading (`load`, `load_state`, `from_pretrained`)**: Enables you to reload a saved model from a checkpoint file or\n  even a pre-trained model from a repository like Hugging Face Hub.\n* **Saving (`save_state`)**: Allows you to save your model's current state, including its parameters and optimizer\n  state.\n* **Optimizer Management (`apply_gradients`, `free_opt_state`, `init_opt_state`)**: Provides methods for updating the\n  model's parameters 
using gradients, releasing optimizer memory, and re-initializing the optimizer if needed.\n* **Sharding (`shard_params`)**:  Helps you distribute your model's parameters efficiently across multiple devices (\n  important for training large models).\n* **PyTorch Conversion (`to_pytorch`)**:  Gives you a way to convert your EasyDeL model to its PyTorch equivalent.\n\n**In Essence:**\n\n`EasyDeLState` streamlines the process of managing, saving, loading, and even converting your EasyDeL models. It ensures\nthat you can easily work with your models and maintain consistency throughout your machine learning workflow.\n\n## Supervised Fine-Tuning with EasyDeL\n\nEasyDeL supports both DPO and SFT Trainers, so dealing with LLMs in jax is a lot easier right now\nlet have an example of using Supervised Fine-Tuner in JAX with EasyDeL\n\n```python\nfrom easydel import (\n    TrainArguments,\n    AutoEasyDeLModelForCausalLM,\n    EasyDeLOptimizers,\n    EasyDeLSchedulers,\n    EasyDeLGradientCheckPointers,\n    SFTTrainer,\n    conversations_formatting_function  # i have added this one for newcomers so if they \n    # don't know what's going on they can use this pre created prompter\n)\nfrom datasets import load_dataset\nimport flax\nfrom jax import numpy as jnp\nfrom transformers import AutoTokenizer\n\nhuggingface_repo_id_or_path = \"mistralai/Mistral-7B-Instruct-v0.2\"\n\nmodel, params = AutoEasyDeLModelForCausalLM.from_pretrained(huggingface_repo_id_or_path, )\n\nmax_length = 4096\ntokenizer = AutoTokenizer.from_pretrained(\n    huggingface_repo_id_or_path,\n    trust_remote_code=True\n)\ntokenizer.pad_token = tokenizer.eos_token\nconfigs_to_initialize_model_class = {\n    \"config\": model.config,\n    \"dtype\": jnp.bfloat16,\n    \"param_dtype\": jnp.bfloat16,\n    \"input_shape\": (1, 1)\n}\n\ntrain_arguments = TrainArguments(\n    model_class=type(model),\n    model_name=\"SFT-EasyDeL\",\n    num_train_epochs=3,\n    
configs_to_initialize_model_class=configs_to_initialize_model_class,\n    learning_rate=5e-5,\n    learning_rate_end=1e-6,\n    optimizer=EasyDeLOptimizers.ADAMW,\n    scheduler=EasyDeLSchedulers.WARM_UP_COSINE,\n    weight_decay=0.01,\n    total_batch_size=32,\n    max_training_steps=None,  # None to let trainer Decide\n    do_train=True,\n    do_eval=False,  # it's optional but supported \n    backend=\"tpu\",  # default backed is set to cpu, so you must define you want to use tpu cpu or gpu\n    max_length=max_length,  # Note that you have to change this in the model config too\n    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,\n    sharding_array=(1, -1, 1, 1),  # the way to shard model across gpu,cpu or TPUs using sharding array (1, -1, 1, 1)\n    # everything training will be in sequence and model parallel automatic and share data between devices\n    remove_ckpt_after_load=True,\n    gradient_accumulation_steps=8,\n    loss_re_mat=\"\",\n    dtype=jnp.bfloat16\n)\n\n\ndef prompter(sample):\n    return [conversations_formatting_function(tokenizer, messages_field=\"messages\")(sample)]\n\n\ntrain_dataset = load_dataset(\"HuggingFaceH4/deita-10k-v0-sft\", split=\"train_sft\")\ntrainer = SFTTrainer(\n    arguments=train_arguments,\n    train_dataset=train_dataset,\n    eval_dataset=None,  # we don't have eval dataset rn :)\n    tokenizer=tokenizer,\n    dataset_text_field=None,\n    formatting_func=prompter,\n    packing=True,\n    num_of_sequences=max_length,\n)\n\noutput = trainer.train(flax.core.FrozenDict({\"params\": params}))\nprint(f\"Hey ! 
, here's where your model saved {output.checkpoint_path}\")\n```\n\n> [!NOTE]\n> You Can use Lora too, for DPO, ORPO and SFT Trainers.\n\n## FineTuning\n\nwith using EasyDeL FineTuning LLM (CausalLanguageModels) are easy as much as possible with using Jax and Flax\nand having the benefit of TPUs for the best speed here's a simple code to use in order to finetune your\nown Model\n\nDays Has Been Passed and now using easydel in Jax is way more similar to HF/PyTorch Style\nnow it's time to finetune our model\n\n```python\nfrom easydel import (\n    TrainArguments,\n    CausalLanguageModelTrainer,\n    AutoEasyDeLModelForCausalLM,\n    EasyDeLOptimizers,\n    EasyDeLSchedulers,\n    EasyDeLGradientCheckPointers\n)\nfrom datasets import load_dataset\nimport flax\nfrom jax import numpy as jnp\nfrom transformers import AutoTokenizer\n\nhuggingface_repo_id_or_path = \"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\"\n\nmodel, params = AutoEasyDeLModelForCausalLM.from_pretrained(huggingface_repo_id_or_path, )\n\nmax_length = 2048\ntokenizer = AutoTokenizer.from_pretrained(\n    huggingface_repo_id_or_path,\n    trust_remote_code=True\n)\ntokenizer.pad_token = tokenizer.eos_token\nconfigs_to_initialize_model_class = {\n    \"config\": model.config,\n    \"dtype\": jnp.bfloat16,\n    \"param_dtype\": jnp.bfloat16,\n    \"input_shape\": (1, 1)\n}\n\ntrain_arguments = TrainArguments(\n    model_class=type(model),\n    model_name=\"my_first_model_to_train_using_easydel\",\n    num_train_epochs=3,\n    configs_to_initialize_model_class=configs_to_initialize_model_class,\n    learning_rate=5e-5,\n    learning_rate_end=1e-6,\n    optimizer=EasyDeLOptimizers.ADAMW,  # \"adamw\", \"lion\", \"adafactor\" are supported\n    scheduler=EasyDeLSchedulers.LINEAR,\n    # \"linear\",\"cosine\", \"none\" ,\"warm_up_cosine\" and \"warm_up_linear\"  are supported\n    weight_decay=0.01,\n    total_batch_size=64,\n    max_training_steps=None,  # None to let trainer Decide\n    
do_train=True,\n    do_eval=False,  # it's optional but supported \n    backend=\"tpu\",  # default backed is set to cpu, so you must define you want to use tpu cpu or gpu\n    max_length=max_length,  # Note that you have to change this in the model config too\n    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,\n    sharding_array=(1, -1, 1, 1),  # the way to shard model across gpu,cpu or TPUs using sharding array (1, -1, 1, 1)\n    # everything training will be in sequence and model parallel automatic and share data between devices\n    remove_ckpt_after_load=True,\n    gradient_accumulation_steps=8,\n    loss_re_mat=\"\",\n    dtype=jnp.bfloat16\n)\n\n\ndef ultra_chat_prompting_process(\n        data_chunk\n):\n    user_part = [\n        chunk[\"content\"] for chunk in data_chunk[\"messages\"] if chunk[\"role\"] == \"user\"\n    ]\n    assistant_part = [\n        chunk[\"content\"] for chunk in data_chunk[\"messages\"] if chunk[\"role\"] == \"assistant\"\n    ]\n\n    prompt = \"\"\n\n    for uc, ac in zip(user_part, assistant_part):\n        prompt += f\"<|user|>\\n{uc}</s>\\n<|assistant|>\\n{ac}</s>\\n\"\n\n    return {\"prompt\": prompt}\n\n\ntokenization_process = lambda data_chunk: tokenizer(\n    data_chunk[\"prompt\"],\n    add_special_tokens=False,\n    max_length=max_length,\n    padding=\"max_length\"\n)\n\ndataset = load_dataset(\"HuggingFaceH4/ultrachat_200k\")\ndataset_train = dataset[\"train_gen\"].map(ultra_chat_prompting_process, num_proc=12)\ndataset_train = dataset_train.map(\n    tokenization_process,\n    num_proc=12,\n    remove_columns=dataset_train.column_names\n)\n\n# you can do the same for evaluation process dataset\n\ntrainer = CausalLanguageModelTrainer(\n    train_arguments,\n    dataset_train,\n    checkpoint_path=None\n)\n\noutput = trainer.train(flax.core.FrozenDict({\"params\": params}))\nprint(f\"Hey ! 
, here's where your model saved {output.checkpoint_path}\")\n```\n\n> [!TIP]\n> you can then convert it to pytorch for better use I don't recommend jax/flax for hosting models since\n> pytorch is better option for gpus\n\n## DPO Fine-tuning\n\n`DPOTrainer` is the new Trainer in EasyDeL, so you might have except some bugs in process but as far as i have tested\neverything works just fine, and you can consider it the first DPO Trainer in JAX/Flax let have an example and see how\nyou can fine-tune your own model with DPOTrainer\n\n> [!TIP]\n> In case that you want a better script to learn about `DPOTrainer` you can see examples\n> at [here](https://github.com/erfanzar/EasyDeL/blob/main/examples/training/dpo/dpo_training_example.py) which contain\n> DPO Tuning a Mixtral model with Intel DPO dataset.\n\n```python\nimport easydel\nfrom easydel import (\n    TrainArguments,\n    EasyDeLOptimizers,\n    EasyDeLSchedulers,\n    EasyDeLGradientCheckPointers,\n    DPOTrainer,\n    EasyDeLState,\n    easystate_to_huggingface_model\n)\n\nfrom datasets import load_dataset\nfrom huggingface_hub import HfApi\nfrom transformers import AutoTokenizer, LlamaForCausalLM as module_pt\nfrom jax import numpy as jnp\nimport jax\nfrom jax.sharding import PartitionSpec\nfrom fjformer import GenerateRNG\nfrom typing import Optional, Dict\nfrom datasets import Dataset\n\nrng_g = GenerateRNG()\napi = HfApi()\n\nmax_length = 512  # Overall maximum length\nmax_target_length = 1024  # Maximum Length for target column in Dataset\nmax_prompt_length = 1024  # Maximum Length for prompt column in Dataset\n\nmodel_name_or_path = \"erfanzar/LinguaMatic-Tiny\"\nref_model_name_or_path = \"teknium/OpenHermes-2.5-Mistral-7B\"\ndtype = jnp.bfloat16\n\nsharding_axis_dims = (1, -1, 1, 1)\nsharding_axis_names = (\"dp\", \"fsdp\", \"tp\", \"sp\")\n\n\ndef extract_anthropic_prompt(prompt_and_response):\n    \"\"\"\n    Extract the anthropic prompt from a prompt and response pair.\n    \"\"\"\n    search_term = 
\"\\n\\nAssistant:\"\n    search_term_idx = prompt_and_response.rfind(search_term)\n    assert search_term_idx != -1, f\"Prompt and response does not contain '{search_term}'\"\n    return prompt_and_response[: search_term_idx + len(search_term)]\n\n\ndef get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: Optional[str] = None) -> Dataset:\n    \"\"\"\n    Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.\n\n    The dataset is converted to a dictionary with the following structure:\n    {\n        'prompt': List[str],\n        'chosen': List[str],\n        'rejected': List[str],\n    }\n\n    Prompts should be structured as follows:\n      \\n\\nHuman: <prompt>\\n\\nAssistant:\n    Multiple turns are allowed, but the prompt should always start with \\n\\nHuman: and end with \\n\\nAssistant:.\n    \"\"\"\n    dataset = load_dataset(\"Anthropic/hh-rlhf\", split=split, cache_dir=cache_dir)\n    if sanity_check:\n        dataset = dataset.select(range(min(len(dataset), 1000)))\n\n    def split_prompt_and_responses(sample) -> Dict[str, str]:\n        prompt = extract_anthropic_prompt(sample[\"chosen\"])\n        return {\n            \"prompt\": prompt,\n            \"chosen\": sample[\"chosen\"][len(prompt):],\n            \"rejected\": sample[\"rejected\"][len(prompt):],\n        }\n\n    return dataset.map(split_prompt_and_responses)\n\n\narguments = TrainArguments(\n    model_name=\"EasyDeL-DPO\",\n    num_train_epochs=5,\n    learning_rate=1e-4,\n    learning_rate_end=3e-5,\n    warmup_steps=200,\n    optimizer=EasyDeLOptimizers.ADAMW,\n    scheduler=EasyDeLSchedulers.LINEAR,\n    weight_decay=0.02,\n    total_batch_size=128,\n    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,\n    sharding_array=sharding_axis_dims,\n    fully_sharded_data_parallel=True,\n    gradient_accumulation_steps=2,\n    dtype=dtype,\n    param_dtype=dtype,\n    step_start_point=0,\n    
training_time=\"7H\",\n    do_train=True,\n    do_eval=True,\n    track_memory=False  # Performance boost.\n    # You can set other options too or play with them but for now I just stick with these arguments.\n)\n\ntokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n\nif tokenizer.pad_token is None:\n    tokenizer.pad_token = tokenizer.eos_token\n\nif tokenizer.pad_token_id is None:\n    tokenizer.pad_token_id = tokenizer.eos_token_id\n\ntrain_dataset = get_hh(\"train\", sanity_check=True)\neval_dataset = get_hh(\"test\", sanity_check=True)\n\nstate = EasyDeLState.from_pretrained(\n    pretrained_model_name_or_path=model_name_or_path,\n    dtype=dtype,\n    param_dtype=dtype,\n    init_optimizer_state=False,\n    free_optimizer_state=True,\n    sharding_axis_dims=sharding_axis_dims,\n    sharding_axis_names=sharding_axis_names,\n    partition_axis=easydel.PartitionAxis(\n        batch_axis=(\"dp\", \"fsdp\"),\n        query_sequence_axis=\"sp\",\n        key_sequence_axis=\"sp\",\n        head_axis=\"tp\",\n        attention_dim_axis=None\n    )\n)\n\nref_state = EasyDeLState.from_pretrained(\n    pretrained_model_name_or_path=ref_model_name_or_path,\n    dtype=dtype,\n    param_dtype=dtype,\n    init_optimizer_state=False,\n    free_optimizer_state=True,\n    sharding_axis_dims=sharding_axis_dims,\n    sharding_axis_names=sharding_axis_names,\n    partition_axis=easydel.PartitionAxis(\n        batch_axis=(\"dp\", \"fsdp\"),\n        query_sequence_axis=\"sp\",\n        key_sequence_axis=\"sp\",\n        head_axis=\"tp\",\n        attention_dim_axis=None\n    )\n)\n\ndpo_trainer = DPOTrainer(\n    model_state=state,\n    ref_model_state=ref_state,\n    beta=0.1,\n    train_dataset=train_dataset,\n    eval_dataset=eval_dataset,\n    tokenizer=tokenizer,\n    arguments=arguments,\n    max_length=max_length,\n    max_target_length=max_target_length,\n    max_prompt_length=max_prompt_length,\n    ref_model_init_kwargs=None,  # In case that you pass the 
ref_model_state a string you have to pass this one too\n    model_init_kwargs=None,  # In case that you pass the model_state a string you have to pass this one too\n    dataset_map_arguments={\n        \"num_proc\": 8,\n        \"batched\": True,\n        \"batch_size\": 100,\n    },\n    auto_shard_model_state=True,\n    auto_shard_ref_model_state=True,\n    loss_type=\"sigmoid\",\n    data_collator=None,  # Pass None in order to use default data_collector (you can create your own)\n)\n\noutput = dpo_trainer.train()\n\neasydel_jax_model = output.state  # Here's you EasyDeL Model\n\nwith jax.default_device(jax.devices(\"cpu\")[0]):\n    model = easystate_to_huggingface_model(\n        state=EasyDeLState.load_state(\n            output.checkpoint_path\n        ),\n        base_huggingface_module=module_pt,\n        config=dpo_trainer.model_state.module.config\n    )  # Here's you PyTorch Model\n\nmodel.push_to_hub(\"<REPO_ID>\", private=False)  # Hope you love open-source too :)\ntokenizer.push_to_hub(\"<REPO_ID>\", private=False)  # Hope you love open-source too :)\n```\n\nnow you have trained your first model Using DPOTrainer in JAX with EasyDeL.\n\n> [!TIP]\n> The API of EasyDeL DPO Trainer is similar to DPO Trainer in TRL from HuggingFace so that means\n> you have freedom and have access to a hackable and changeable code.\n\n## EasyDeLState\n\nEasyDeLState is new and cool feature in EasyDeL and have a lot of options like\nstoring `Model Parameters`, _Optimizer State,\nModel Config, Model Type, Optimizer and Scheduler Configs_\n\nLet see and examples of using EasyDeLState\n\n### Fine-tuning\n\nFine-tuning from a previous State or a new state\n\n```python\nfrom easydel import (\n    AutoEasyDeLConfig,\n    EasyDeLState,\n    PartitionAxis\n)\nfrom transformers import AutoTokenizer\nfrom jax import numpy as jnp, lax\nimport jax\n\nhuggingface_model_repo_id = \"REPO_ID\"\ncheckpoint_name = \"CKPT_NAME\"\n\nstate = EasyDeLState.from_pretrained(\n    
pretrained_model_name_or_path=huggingface_model_repo_id,\n    filename=checkpoint_name,\n    optimizer=\"adamw\",\n    scheduler=\"none\",\n    tx_init=None,\n    device=jax.devices('cpu')[0],  # Offload Device\n    dtype=jnp.bfloat16,\n    param_dtype=jnp.bfloat16,\n    precision=lax.Precision(\"fastest\"),\n    sharding_axis_dims=(1, -1, 1, 1),\n    # the way to shard model across gpu,cpu or TPUs using sharding array (1, -1, 1, 1)\n    # everything training will be in sequence and model parallel automatic and share data between devices\n    sharding_axis_names=(\"dp\", \"fsdp\", \"tp\", \"sp\"),\n    partition_axis=PartitionAxis(\n        batch_axis=(\"dp\", \"fsdp\"),\n        query_sequence_axis=\"sp\",\n        key_sequence_axis=\"sp\",\n        head_axis=\"tp\",\n        attention_dim_axis=None\n    ),\n    shard_attention_computation=True,\n    input_shape=(1, 1),\n    backend=None,\n    init_optimizer_state=False,\n    free_optimizer_state=True,\n    verbose=True,\n    state_shard_fns=None,\n)\n\nconfig = AutoEasyDeLConfig.from_pretrained(\n    huggingface_model_repo_id\n)\n\ntokenizer = AutoTokenizer.from_pretrained(\n    huggingface_model_repo_id,\n    trust_remote_code=True\n)\n\nmax_length = config.max_position_embeddings\n\nconfigs_to_initialize_model_class = {\n    'config': config,\n    'dtype': jnp.bfloat16,\n    'param_dtype': jnp.bfloat16,\n    'input_shape': (8, 8)\n}\n```\n\n`EasyDeLState` also has `.load_state()` and `.save_state()` with some other usable options like `.free_opt_state()`\nwhich\nfree optimizer state or `.shard_params()` which shard parameters you can read docs in order to find out more about these\noptions.\n\n### Converting to Huggingface and Pytorch\n\nLet see how you can convert a EasyDeLMistral Model to Huggingface Pytorch Mistral Model from a trained State\n\n```python\n\nfrom transformers import MistralForCausalLM\nfrom easydel import (\n    AutoEasyDeLConfig,\n    EasyDeLState,\n    
easystate_to_huggingface_model\n)\nimport jax\n\nhuggingface_model_repo_id = \"REPO_ID\"\n\nconfig = AutoEasyDeLConfig.from_pretrained(\n    huggingface_model_repo_id\n)\nwith jax.default_device(jax.devices(\"cpu\")[0]):\n    model = easystate_to_huggingface_model(\n        state=EasyDeLState.load_state(\n            \"PATH_TO_CKPT\"\n        ),  # You can Pass EasyDeLState here\n        base_huggingface_module=MistralForCausalLM,  # type: ignore\n        config=config\n    )\n\nmodel = model.half()  # it's a huggingface model now\n```\n\n### Other Use Cases\n\n`EasyDeLState` have a general use you can use it everywhere in easydel for example for a stand-alone model\n, serve, fine-tuning and many other features, it's up to you to test how creative you are \ud83d\ude07.\n\n## AttentionModule: A Versatile Attention Mechanism Factory\n\nThe `AttentionModule` class is designed to simplify the creation and execution of different attention mechanisms within\nyour EasyDeL models. It provides a unified interface for working with various attention types, allowing you to easily\nswitch between them and experiment with different configurations.\n\n**Key Features:**\n\n* **Mechanism Selection:** The `attn_mechanism` argument lets you choose the specific attention algorithm you want to\n  use (e.g., \"vanilla,\" \"flash,\" \"splash,\" \"ring,\" \"cudnn\").\n* **Sharding and Partitioning:** The class supports advanced JAX sharding techniques to distribute attention\n  computations across multiple devices for efficient processing of large models. 
It handles partitioning of query, key,\n  value, bias, and attention weight matrices using `PartitionSpec`.\n* **Blockwise Attention:** Enables blockwise attention for better memory efficiency, especially with long\n  sequences.\n* **Caching Support:** Supports attention caching to speed up inference and generation tasks.\n* **Dropout and Determinism:** Allows applying dropout to attention weights and controlling whether the attention\n  computation is deterministic.\n* **Testing Utility:** Provides a `test_attentions` method to compare different attention mechanisms in terms of\n  accuracy, gradient stability, and computation time.\n\n**How it Works:**\n\n1. **Initialization:**\n    - During initialization, you provide the desired `attn_mechanism`, a JAX `mesh` for sharding, the scaling\n      factor (`sm_scale`), the number of attention heads, the head dimensions, and other configuration parameters.\n    - The class automatically sets default values for many parameters based on the chosen attention mechanism and the\n      provided EasyDeL configuration (`base_module_class`).\n2. 
**Calling the Module:**\n    - When you call the `AttentionModule` object, you pass in the query, key, and value states, along with optional\n      parameters such as attention masks, biases, and causal flags.\n    - The module internally selects the appropriate attention function based on the specified `attn_mechanism`.\n    - It performs any necessary sharding and partitioning based on the configured partition specifications.\n    - The attention computation is executed, and the attention outputs (and optionally the attention weights) are returned.\n\n**Advantages:**\n\n* **Flexibility:** Allows you to easily switch between different attention mechanisms without major code changes.\n* **Efficiency:** Supports advanced JAX sharding for distributed computation, enabling the handling of large models.\n\n_Flash Attention works well on TPUs; GPU support is still being improved._\n\n## EasyDeLXRapTure for layer tuning and LoRA\n\nWhen applying LoRA to EasyDeL models, there are a few things you may need to configure yourself, but EasyDeL handles\nmost of the work. Let's jump straight into a LoRA fine-tuning example that uses _EasyDeLXRapTure_ with a Mistral model\nand Flash Attention:\n\n```python\nfrom flax.core import FrozenDict\nfrom easydel import (\n    TrainArguments,\n    CausalLanguageModelTrainer,\n    AutoEasyDeLModelForCausalLM,\n    EasyDeLOptimizers,\n    EasyDeLSchedulers,\n    EasyDeLGradientCheckPointers,\n    EasyDeLXRapTureConfig\n)\nfrom datasets import load_dataset\nfrom jax import numpy as jnp\nfrom transformers import AutoTokenizer\n\nhuggingface_repo_id_or_path = \"mistralai/Mistral-7B-Instruct-v0.1\"\n\nmodel, params = AutoEasyDeLModelForCausalLM.from_pretrained(huggingface_repo_id_or_path)\n\nmax_length = 8192\nmodel_parameters = FrozenDict({\"params\": params})\n\ndtype = jnp.bfloat16\nparam_dtype = jnp.bfloat16  # you can change that if you want 
\n\ntokenizer = AutoTokenizer.from_pretrained(\n    huggingface_repo_id_or_path,\n    trust_remote_code=True\n)\n\nmodel.config.add_basic_configurations(\n    attn_mechanism=\"flash\",  # use FlashAttention\n    block_b=1,\n    block_q=1024,\n    block_k=1024,\n    block_k_major=1024,\n)\n\ntokenizer.pad_token = tokenizer.eos_token\nconfigs_to_initialize_model_class = {\n    \"config\": model.config,\n    \"dtype\": dtype,\n    \"param_dtype\": param_dtype,\n    \"input_shape\": (1, 1)\n}\n\nrapture = EasyDeLXRapTureConfig(\n    parameters=model_parameters,\n    lora_dim=64,\n    fully_fine_tune_parameters=[\"embed_tokens\"],  # layers to fully fine-tune\n    lora_fine_tune_parameters=[\"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\"],  # LoRA target layers; set to None\n    # for layer tuning or transfer learning only\n    verbose=True\n)\n\ntrain_arguments = TrainArguments(\n    model_class=type(model),\n    model_name=\"EasyDeL-Lora-Example\",\n    num_train_epochs=3,\n    configs_to_initialize_model_class=configs_to_initialize_model_class,\n    learning_rate=1e-4,  # a higher learning rate is recommended for LoRA\n    learning_rate_end=8e-5,\n    optimizer=EasyDeLOptimizers.ADAMW,  # \"adamw\", \"lion\", and \"adafactor\" are supported\n    scheduler=EasyDeLSchedulers.LINEAR,\n    # \"linear\", \"cosine\", \"none\", \"warm_up_cosine\", and \"warm_up_linear\" are supported\n    weight_decay=0.01,\n    total_batch_size=512,\n    max_training_steps=None,  # None lets the trainer decide\n    do_train=True,\n    do_eval=False,  # optional; evaluation is supported\n    backend=\"tpu\",  # the default backend is cpu, so set \"tpu\", \"gpu\", or \"cpu\" explicitly\n    max_length=max_length,  # note that you have to change this in the model config too\n    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,\n    sharding_array=(1, -1, 1, 1),  # mesh shape used to shard the model across GPUs, CPUs, or TPUs;\n    # -1 assigns all remaining devices to that axis, and EasyDeL shares data between devices automatically
\n    remove_ckpt_after_load=True,\n    gradient_accumulation_steps=1,\n    loss_re_mat=\"\",\n    dtype=dtype,\n    param_dtype=param_dtype,\n    rapture_config=rapture,\n    merge_lora_rapture_parameters=True  # turning this off is not supported yet and is not recommended\n    # merges the LoRA parameters into the original model parameters at the end of training\n)\n\n\ndef ultra_chat_prompting_process(\n        data_chunk\n):\n    user_part = [\n        chunk[\"content\"] for chunk in data_chunk[\"messages\"] if chunk[\"role\"] == \"user\"\n    ]\n    assistant_part = [\n        chunk[\"content\"] for chunk in data_chunk[\"messages\"] if chunk[\"role\"] == \"assistant\"\n    ]\n\n    prompt = \"\"\n\n    for uc, ac in zip(user_part, assistant_part):\n        prompt += f\"<|user|>\\n{uc}</s>\\n<|assistant|>\\n{ac}</s>\\n\"\n\n    return {\"prompt\": prompt}\n\n\ntokenization_process = lambda data_chunk: tokenizer(\n    data_chunk[\"prompt\"],\n    add_special_tokens=False,\n    max_length=max_length,\n    truncation=True,  # cut prompts longer than max_length so every example has the same length\n    padding=\"max_length\"\n)\n\ndataset = load_dataset(\"HuggingFaceH4/ultrachat_200k\")\ndataset_train = dataset[\"train_gen\"].map(ultra_chat_prompting_process, num_proc=12)\ndataset_train = dataset_train.map(\n    tokenization_process,\n    num_proc=12,\n    remove_columns=dataset_train.column_names\n)\n\n# you can process an evaluation dataset the same way\n\ntrainer = CausalLanguageModelTrainer(\n    train_arguments,\n    dataset_train,\n    checkpoint_path=None\n)\n\noutput = trainer.train()  # do not pass parameters to trainer.train() when\n# you are using LoRA or transfer learning\nprint(f\"Hey! Here's where your model was saved: {output.checkpoint_path}\")\n```\n\n## Contributing\n\nEasyDeL is an open-source project, and contributions are welcome. 
If you would like to contribute to EasyDeL, please\nfork the repository, make your changes, and submit a pull request. The team behind EasyDeL will review your changes and\nmerge them if they are suitable.\n\n## License \ud83d\udcdc\n\nEasyDeL is fully open source and released under the Apache 2.0 license. Please see the LICENSE file in the root\ndirectory of this project for more information.\n\n## Contact\n\nIf you have any questions or comments about EasyDeL, you can reach out to me.\n\n## Citing EasyDeL \ud83e\udd76\n\nTo cite this repository:\n\n```misc\n@misc{ZareChavoshi_2023,\n    title={EasyDeL, an open-source library, is specifically designed to enhance and streamline the training process of machine learning models. It focuses primarily on Jax/Flax and aims to provide convenient and effective solutions for training Flax/Jax Models on TPU/GPU for both Serving and Training purposes.},\n    url={https://github.com/erfanzar/EasyDeL},\n    journal={EasyDeL Easy and Fast DeepLearning with JAX},\n    publisher={Erfan Zare Chavoshi},\n    author={Zare Chavoshi, Erfan},\n    year={2023}\n} \n```\n",
    "bugtrack_url": null,
    "license": "Apache-2.0",
    "summary": "An open-source library to make training faster and more optimized in Jax/Flax",
    "version": "0.0.67",
    "project_urls": {
        "Documentation": "https://erfanzar.github.io/EasyDeL",
        "Homepage": "https://github.com/erfanzar/EasyDeL",
        "Issues": "https://github.com/erfanzar/EasyDeL/issues"
    },
    "split_keywords": [
        "jax",
        " torch",
        " deep learning",
        " machine learning",
        " flax",
        " xla"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "3e7a04d20698c85e949ca1bac9b301ae066eed2a5210b0979f44743e454a694c",
                "md5": "e2db2253d8b4f876fb9c430520dcb911",
                "sha256": "fabd448d307434c81886d191a122f1045a67fa29b7defc06e508277ae37f8311"
            },
            "downloads": -1,
            "filename": "EasyDeL-0.0.67-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "e2db2253d8b4f876fb9c430520dcb911",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.8",
            "size": 483264,
            "upload_time": "2024-06-02T17:12:58",
            "upload_time_iso_8601": "2024-06-02T17:12:58.034087Z",
            "url": "https://files.pythonhosted.org/packages/3e/7a/04d20698c85e949ca1bac9b301ae066eed2a5210b0979f44743e454a694c/EasyDeL-0.0.67-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "92c5ffad63ed91ad7efc1ff6315e790b6e167d7b62f00ca5fe78c438378104ed",
                "md5": "dc8ad185379bc994ed79f2851a5582e6",
                "sha256": "1a7843bbc9dbc7dfaa004b474d6431e6922425a02110d66fad79616bd5605f19"
            },
            "downloads": -1,
            "filename": "easydel-0.0.67.tar.gz",
            "has_sig": false,
            "md5_digest": "dc8ad185379bc994ed79f2851a5582e6",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.8",
            "size": 424987,
            "upload_time": "2024-06-02T17:14:40",
            "upload_time_iso_8601": "2024-06-02T17:14:40.785081Z",
            "url": "https://files.pythonhosted.org/packages/92/c5/ffad63ed91ad7efc1ff6315e790b6e167d7b62f00ca5fe78c438378104ed/easydel-0.0.67.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2024-06-02 17:14:40",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "erfanzar",
    "github_project": "EasyDeL",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": false,
    "requirements": [],
    "lcname": "easydel"
}
        