trl

Name: trl
Version: 0.8.4
Home page: https://github.com/huggingface/trl
Summary: Train transformer language models with reinforcement learning.
Upload time: 2024-04-17 15:16:50
Author: Leandro von Werra
Requires Python: >=3.7
License: Apache 2.0
Keywords: ppo, transformers, huggingface, gpt2, language modeling, rlhf
Requirements: datasets>=1.17.0, torch>=1.4.0, tqdm, transformers, accelerate, peft>=0.3.0, tyro>=0.5.7

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_banner_dark.png">
</div>

# TRL - Transformer Reinforcement Learning
> Full stack library to fine-tune and align large language models.

<p align="center">
    <a href="https://github.com/huggingface/trl/blob/main/LICENSE">
        <img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue">
    </a>
    <a href="https://huggingface.co/docs/trl/index">
        <img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_message=online">
    </a>
    <a href="https://github.com/huggingface/trl/releases">
        <img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg">
    </a>
</p>


## What is it?

The `trl` library is a full stack tool to fine-tune and align transformer language and diffusion models using methods such as Supervised Fine-Tuning (SFT), Reward Modeling (RM), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO).

The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library and thus lets you use any model architecture available there.


## Highlights

- **`Efficient and scalable`**:
    - [`accelerate`](https://github.com/huggingface/accelerate) is the backbone of `trl`, allowing you to scale model training from a single GPU to a large multi-node cluster with methods such as DDP and DeepSpeed.
    - [`PEFT`](https://github.com/huggingface/peft) is fully integrated, allowing you to train even the largest models on modest hardware with quantization and methods such as LoRA or QLoRA.
    - [`unsloth`](https://github.com/unslothai/unsloth) is also integrated and significantly speeds up training with dedicated kernels.
- **`CLI`**: With the [CLI](https://huggingface.co/docs/trl/clis) you can fine-tune LLMs and chat with them without writing any code, using a single command and a flexible config system.
- **`Trainers`**: The trainer classes are abstractions that make it easy to apply many fine-tuning methods, such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer), [`CPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.CPOTrainer), and [`ORPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.ORPOTrainer).
- **`AutoModels`**: The [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead) classes add a value head to the model, which allows it to be trained with RL algorithms such as PPO.
- **`Examples`**: Train GPT-2 to generate positive movie reviews with a BERT sentiment classifier, run full RLHF using only adapters, train GPT-J to be less toxic, reproduce the [StackLlama example](https://huggingface.co/blog/stackllama), and more by following the [examples](https://github.com/huggingface/trl/tree/main/examples).

## Installation

### Python package
Install the library with `pip`:
```bash
pip install trl
```

### From source
If you want to use the latest features before an official release you can install from source:
```bash
pip install git+https://github.com/huggingface/trl.git
```

### Repository
If you want to use the examples you can clone the repository with the following command:
```bash
git clone https://github.com/huggingface/trl.git
```

## Command Line Interface (CLI)

You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), and to test your aligned model with the chat CLI:

**SFT:**

```bash
trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb
```

**DPO:**

```bash
trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/hh-rlhf-trl-style --output_dir opt-sft-hh-rlhf 
```

**Chat:**

```bash
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
```

Read more about the CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.

## How to use

For more flexibility and control over training, you can use the dedicated trainer classes to fine-tune models in Python.

### `SFTTrainer`

This is a basic example of how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.

```python
# imports
from datasets import load_dataset
from trl import SFTTrainer

# get dataset
dataset = load_dataset("imdb", split="train")

# get trainer
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)

# train
trainer.train()
```
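
Since [`PEFT`](https://github.com/huggingface/peft) is integrated (see the highlights above), you can train a LoRA adapter instead of the full model by passing a `peft_config`. The sketch below reuses the setup from the example above; the LoRA hyperparameters are illustrative, not tuned recommendations.

```python
# imports
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer

# get dataset
dataset = load_dataset("imdb", split="train")

# LoRA adapter configuration (illustrative hyperparameters)
peft_config = LoraConfig(
    r=16,              # rank of the low-rank update matrices
    lora_alpha=32,     # scaling factor applied to the update
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

# get trainer - with a peft_config only the adapter weights are trained
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
    peft_config=peft_config,
)

# train
trainer.train()
```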

### `RewardTrainer`

This is a basic example of how to use the `RewardTrainer` from the library. The `RewardTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.

```python
# imports
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

# load model and dataset - dataset needs to be in a specific format
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...

# load trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()
```
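
The elided step above prepares a preference dataset of chosen/rejected pairs. Below is a minimal sketch of what that preparation could look like, continuing from the snippet above and assuming a pairwise dataset such as `Anthropic/hh-rlhf`; the tokenized column names follow the format documented for the reward trainer, but treat the sketch as illustrative rather than canonical.

```python
# continuing from the snippet above: prepare a pairwise preference dataset
from datasets import load_dataset

# assumed dataset with "chosen" and "rejected" text columns
raw_dataset = load_dataset("Anthropic/hh-rlhf", split="train")

# the GPT-2 tokenizer has no pad token, so reuse the EOS token for padding
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

def tokenize_pairs(examples):
    chosen = tokenizer(examples["chosen"], truncation=True)
    rejected = tokenizer(examples["rejected"], truncation=True)
    return {
        "input_ids_chosen": chosen["input_ids"],
        "attention_mask_chosen": chosen["attention_mask"],
        "input_ids_rejected": rejected["input_ids"],
        "attention_mask_rejected": rejected["attention_mask"],
    }

# tokenize both sides of every preference pair into the columns the reward trainer reads
dataset = raw_dataset.map(tokenize_pairs, batched=True)
```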

### `PPOTrainer`

This is a basic example of how to use the `PPOTrainer` from the library. Given a query, the language model generates a response, which is then evaluated. The evaluation could come from a human in the loop or from another model's output.

```python
# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch

# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# initialize trainer
ppo_config = PPOConfig(batch_size=1, mini_batch_size=1)

# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# get model response
response_tensor = respond_to_batch(model, query_tensor)

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)

# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]

# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```
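
In practice the constant reward above would be replaced by a real signal. The sketch below, continuing from the variables defined in the snippet above, uses a sentiment-analysis pipeline as a stand-in reward model; the model name is taken from the TRL sentiment examples and any text classifier could be substituted.

```python
# continuing from the snippet above: replace the constant reward with a classifier score
from transformers import pipeline

# sentiment classifier acting as a stand-in reward model
# (model name assumed from the TRL sentiment examples)
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb")

# decode the generated response and score the full query + response text
response_txt = tokenizer.decode(response_tensor[0])
pipe_output = sentiment_pipe(query_txt + response_txt)[0]

# use the classifier's confidence as a naive scalar reward for the PPO step
reward = [torch.tensor(pipe_output["score"])]
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```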

### `DPOTrainer`

`DPOTrainer` is a trainer that uses the [Direct Preference Optimization algorithm](https://arxiv.org/abs/2305.18290). This is a basic example of how to use the `DPOTrainer` from the library. The `DPOTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom preference dataset.

```python
# imports
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer

# load model and dataset - dataset needs to be in a specific format
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...

# load trainer
trainer = DPOTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()
```
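
As in the reward modeling example, the elided step loads a preference dataset. DPO expects prompt/chosen/rejected triples; the sketch below builds a tiny dataset in memory purely to show the expected format (the rows are invented).

```python
# a tiny in-memory preference dataset in the prompt/chosen/rejected format expected by DPO
# (the rows are invented purely for illustration)
from datasets import Dataset

dataset = Dataset.from_dict({
    "prompt": ["What is the capital of France?"],
    "chosen": ["The capital of France is Paris."],
    "rejected": ["France is a country in Europe."],
})
```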

## Development

If you want to contribute to `trl` or customize it to your needs, make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and do a dev install:

```bash
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
```

## References

### Proximal Policy Optimization
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].

### Direct Preference Optimization
DPO is based on the original implementation of **"Direct Preference Optimization: Your Language Model is Secretly a Reward Model"** by R. Rafailov et al. \[[paper](https://arxiv.org/pdf/2305.18290.pdf), [code](https://github.com/eric-mitchell/direct-preference-optimization)].


## Citation

```bibtex
@misc{vonwerra2022trl,
  author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang},
  title = {TRL: Transformer Reinforcement Learning},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/huggingface/trl}}
}
```

            
