trl

Name: trl
Version: 0.15.0
Home page: https://github.com/huggingface/trl
Summary: Train transformer language models with reinforcement learning.
Upload time: 2025-02-13 14:37:40
Author: Leandro von Werra
Requires Python: >=3.9
License: Apache 2.0
Keywords: ppo, transformers, huggingface, gpt2, language modeling, rlhf
Requirements: accelerate, datasets, rich, transformers

# TRL - Transformer Reinforcement Learning

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png" alt="TRL Banner">
</div>

<hr> <br>

<h3 align="center">
    <p>A comprehensive library to post-train foundation models</p>
</h3>

<p align="center">
    <a href="https://github.com/huggingface/trl/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue"></a>
    <a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
    <a href="https://github.com/huggingface/trl/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg"></a>
</p>

## Overview

TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled up across various hardware setups.

## Highlights

- **Efficient and scalable**: 
    - Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like DDP and DeepSpeed.
    - Full integration with [`PEFT`](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA (see the sketch after this list).
    - Integrates [Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.

- **Command Line Interface (CLI)**: A simple interface lets you fine-tune and interact with models without needing to write code.

- **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`ORPOTrainer`](https://huggingface.co/docs/trl/orpo_trainer) and more.

- **AutoModels**: Use pre-defined model classes like [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) to simplify reinforcement learning (RL) with LLMs.
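
As a concrete illustration of the PEFT integration mentioned above, here is a minimal sketch of supervised fine-tuning with a LoRA adapter by passing a `peft_config` to `SFTTrainer`; the LoRA hyperparameters below are illustrative placeholders rather than recommended values.

```python
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

# Illustrative LoRA configuration; tune these values for your own setup.
peft_config = LoraConfig(
    r=16,               # rank of the LoRA update matrices
    lora_alpha=32,      # scaling factor applied to the LoRA updates
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

training_args = SFTConfig(output_dir="Qwen2.5-0.5B-SFT-LoRA")
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,  # only the adapter weights are trained
)
trainer.train()
```

For QLoRA, the base model would additionally be loaded in 4-bit (e.g. via `bitsandbytes`) before the adapter is attached.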

## Installation

### Python Package

Install the library using `pip`:

```bash
pip install trl
```

### From source

If you want to use the latest features before an official release, you can install TRL from source:

```bash
pip install git+https://github.com/huggingface/trl.git
```

### Repository

If you want to use the examples, you can clone the repository with the following command:

```bash
git clone https://github.com/huggingface/trl.git
```

## Command Line Interface (CLI)

You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), or vibe check your model with the chat CLI: 

**SFT:**

```bash
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name trl-lib/Capybara \
    --output_dir Qwen2.5-0.5B-SFT
```

**DPO:**

```bash
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --dataset_name argilla/Capybara-Preferences \
    --output_dir Qwen2.5-0.5B-DPO 
```

**Chat:**

```bash
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
```

Read more about the CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.

## How to use

For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.

### `SFTTrainer`

Here is a basic example of how to use the `SFTTrainer`:

```python
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(output_dir="Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
    args=training_args,
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
)
trainer.train()
```
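
The dataset does not have to come from the Hub. As a purely illustrative sketch (the example conversations below are made up), a small in-memory dataset in the conversational format (a `"messages"` column of chat turns) can be passed in the same way:

```python
from datasets import Dataset
from trl import SFTConfig, SFTTrainer

# Tiny in-memory dataset in conversational format; the examples are made up.
train_dataset = Dataset.from_dict({
    "messages": [
        [
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "The capital of France is Paris."},
        ],
        [
            {"role": "user", "content": "Name a prime number greater than 10."},
            {"role": "assistant", "content": "13 is a prime number greater than 10."},
        ],
    ]
})

training_args = SFTConfig(output_dir="Qwen2.5-0.5B-SFT-demo", max_steps=10)
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()
```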

### `RewardTrainer`

Here is a basic example of how to use the `RewardTrainer`:

```python
from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
model.config.pad_token_id = tokenizer.pad_token_id

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()
```

### `GRPOTrainer`

`GRPOTrainer` implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300), which is more memory-efficient than PPO and was used to train [DeepSeek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Dummy reward function: rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```
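
Reward functions are plain Python callables: they receive the list of generated completions (plus, depending on the setup, extra columns from the dataset via `**kwargs`) and return one score per completion. As another purely illustrative sketch, the function below rewards completions that contain a digit; `reward_funcs` can also be given a list of such functions.

```python
import re

# Illustrative custom reward: 1.0 if the completion contains a digit, 0.0 otherwise.
# Like the length-based reward above, it returns one float per completion.
def reward_contains_digit(completions, **kwargs):
    return [1.0 if re.search(r"\d", completion) else 0.0 for completion in completions]
```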

### `DPOTrainer`

`DPOTrainer` implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the `DPOTrainer`:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
trainer.train()
```
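
The preference data itself is just a prompt with a preferred and a rejected answer. As a purely illustrative sketch (the rows below are made up), a tiny dataset in the standard `prompt`/`chosen`/`rejected` format can be built directly and passed to `DPOTrainer` as `train_dataset`:

```python
from datasets import Dataset

# Tiny, made-up preference dataset in the standard prompt/chosen/rejected format.
preference_dataset = Dataset.from_dict({
    "prompt": [
        "What color is the sky on a clear day?",
        "How many legs does a spider have?",
    ],
    "chosen": [
        "On a clear day the sky appears blue.",
        "A spider has eight legs.",
    ],
    "rejected": [
        "The sky is green.",
        "A spider has six legs.",
    ],
})
```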

## Development

If you want to contribute to `trl` or customize it to your needs, make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and set up a dev install:

```bash
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .[dev]
```

## Citation

```bibtex
@misc{vonwerra2022trl,
  author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
  title = {TRL: Transformer Reinforcement Learning},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/huggingface/trl}}
}
```

## License

This repository's source code is available under the [Apache-2.0 License](LICENSE).

            
