<p align="center">
<img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e23f40d33e8a2fa4220e8122c152b356084b8afb/logo.png" width=90% alt="OAT" />
</p>
[![PyPI - Version](https://img.shields.io/pypi/v/oat-llm.svg)](https://pypi.org/project/oat-llm)
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/oat-llm.svg)](https://pypi.org/project/oat-llm)
[![License](https://img.shields.io/github/license/sail-sg/oat)](https://github.com/sail-sg/oat/blob/main/LICENSE)
[![arXiv](https://img.shields.io/badge/arXiv-2411.01493-b31b1b.svg)](https://arxiv.org/abs/2411.01493)
[Installation](#installation) | [Usage](#usage) | [Examples](./examples/) | [Benchmarking](#benchmarking) | [Citation](#citation)
---
## Introduction
Oat 🌾 is a simple yet efficient system for running online LLM alignment algorithms. Its key features include:
* **High Efficiency**: Oat implements a distributed *Actor-Learner-Oracle* architecture, with each component being optimized using state-of-the-art tools:
* `Actor`: Utilizes [vLLM](https://github.com/vllm-project/vllm) for accelerated online response sampling.
* `Learner`: Leverages [DeepSpeed](https://github.com/microsoft/DeepSpeed) ZeRO strategies to enhance memory efficiency.
* `Oracle`: Hosted by [Mosec](https://github.com/mosecorg/mosec) as a remote service, supporting dynamic batching, data parallelism and pipeline parallelism.
* **Simplified Workflow**: Oat simplifies the experimental pipeline of LLM alignment. With an `Oracle` served online, we can flexibly query it for preference data labeling as well as anytime model evaluation. All you need to do is launch experiments and monitor real-time learning curves (e.g., win rate) on wandb (see [reproduced results](https://wandb.ai/lkevinzc/oat-llm)) — no need for manual training, checkpointing, and loading for evaluation.
* **Oracle Simulation**: Oat provides simulated preference oracles in various modes.
* Lightweight reward models run within the actor's process, enabling quick testing on as few as two GPUs.
* Larger and more capable reward models can be served remotely, harnessing additional compute and memory resources.
* LLM-as-a-judge is supported by querying the OpenAI API for model-based pairwise ranking.
* **Ease of Use**: Oat's modular structure allows researchers to easily inherit and modify existing classes, enabling rapid prototyping and experimentation with new algorithms.
* **Cutting-Edge Algorithms**: Oat implements state-of-the-art LLM exploration (active alignment) algorithms, including [SEA](https://arxiv.org/abs/2411.01493), APL and XPO, along with popular direct optimizers such as DPO and SimPO, fostering innovation and fair benchmarking.
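To make the LLM-as-a-judge mode concrete, a pairwise comparison reduces to a single chat completion. The sketch below is hypothetical — the function names and prompt wording are ours for illustration, not oat's actual implementation — and assumes the standard `openai` Python client for the (commented-out) network call:

```python
# Hypothetical sketch of LLM-as-a-judge pairwise ranking; function names
# and prompt wording are illustrative, not oat's actual API.
def build_judge_messages(prompt, response_a, response_b):
    """Construct a chat request asking the judge to pick the better response."""
    return [
        {"role": "system",
         "content": "You are an impartial judge. Reply with exactly 'A' or 'B'."},
        {"role": "user",
         "content": (
             f"Prompt:\n{prompt}\n\n"
             f"Response A:\n{response_a}\n\n"
             f"Response B:\n{response_b}\n\n"
             "Which response is better? Answer 'A' or 'B'."
         )},
    ]

def parse_preference(judge_reply):
    """Map the judge's raw reply to 0 (A preferred) or 1 (B preferred)."""
    verdict = judge_reply.strip().upper()
    return 0 if verdict.startswith("A") else 1

# Usage with the OpenAI client (network call, shown for context only):
# from openai import OpenAI
# reply = OpenAI().chat.completions.create(
#     model="gpt-4o-mini", messages=build_judge_messages(p, a, b)
# ).choices[0].message.content
# preferred = parse_preference(reply)
```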
## LLM alignment as contextual dueling bandits
LLM alignment is essentially an online learning and decision-making problem where the **agent** (e.g., the LLM policy with an optional built-in reward model) interacts with the **environment** (i.e., humans) to achieve one of two distinct objectives: minimizing cumulative regret in the *Explore & Exploit* setting, or minimizing anytime regret in the *Best Arm Identification* setting.
In our [paper](https://arxiv.org/abs/2411.01493), we formalize LLM alignment as a **contextual dueling bandit (CDB)** problem (see illustration below) and propose a sample-efficient alignment approach based on Thompson sampling.
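Concretely, in a CDB the agent proposes two responses $(y_1, y_2)$ for a context $x$ and observes only a binary preference between them. A common modeling assumption — used here for illustration; see the paper for the exact formulation — is a Bradley-Terry model with a latent reward $r$:

$$
P(y_1 \succ y_2 \mid x) = \sigma\big(r(x, y_1) - r(x, y_2)\big), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}
$$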
<p align="center">
<img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e0da719024bdc16fb4a993a8405e15cb0cf2b53a/interface.png" width=80%/>
</p>
The CDB framework necessitates an efficient online training system to validate the proposed method and compare it with other baselines. Oat 🌾 is developed as part of this research initiative.
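To give a flavor of the Thompson-sampling idea, here is a toy, self-contained sketch of ours — not oat's implementation (oat's `EnnBAITS` uses a learned epistemic reward model, whereas the toy below uses a fixed score ensemble): sample one ensemble member and duel its favorite candidate against the best candidate under the ensemble mean.

```python
import random

# Toy sketch of ensemble-based Thompson sampling for picking a duel.
# Illustrative only: oat's EnnBAITS trains an epistemic reward model;
# here a fixed list of score vectors stands in for the ensemble.
def thompson_duel(ensemble_scores, rng):
    """Pick two candidate indices to duel.

    ensemble_scores[k][i] is ensemble member k's reward estimate
    for candidate response i.
    """
    # First arm: a randomly sampled member's best candidate (exploration).
    sampled = rng.choice(ensemble_scores)
    first = max(range(len(sampled)), key=sampled.__getitem__)
    # Second arm: the best candidate under the ensemble mean (exploitation).
    n = len(ensemble_scores[0])
    mean = [sum(m[i] for m in ensemble_scores) / len(ensemble_scores)
            for i in range(n)]
    second = max(range(n), key=mean.__getitem__)
    return first, second

# Three ensemble members scoring four candidate responses.
scores = [[0.1, 0.9, 0.3, 0.2],
          [0.2, 0.4, 0.8, 0.1],
          [0.0, 0.7, 0.6, 0.3]]
print(thompson_duel(scores, random.Random(0)))
```

Disagreement among members (here, member 1 prefers candidate 2 while the others prefer candidate 1) is exactly what drives the exploratory first arm.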
Using the CDB framework, existing LLM alignment paradigms can be summarized as follows:
<p align="center">
<img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/acbb25a20dd6c1e7619539b0fa449076ade2f873/compare.png" width=95%/>
</p>
For more details, please check out our [paper](https://arxiv.org/abs/2411.01493)!
## Installation
In a Python environment with a supported version (`>=3.8, <=3.10`), you can install oat via PyPI:
```shell
pip install vllm==0.6.2 && pip install oat-llm
```
Alternatively, install in "editable" mode for local development:
```shell
git clone git@github.com:sail-sg/oat.git
cd oat
pip install vllm==0.6.2 && pip install -e .
```
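Before installing, you may want to confirm that your interpreter falls in the supported range; a quick check (our helper, not part of oat):

```python
import sys

# oat supports Python >=3.8 and <3.11, i.e. 3.8 through 3.10.
def is_supported(version_info=None):
    major_minor = (version_info or sys.version_info)[:2]
    return (3, 8) <= major_minor < (3, 11)

print(is_supported((3, 10, 4)))  # True: 3.10 is the newest supported minor
```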
## Usage
Below is an example that aligns a `1B Pythia` SFT model on the `tl;dr` dataset using `online SimPO`, with `PairRM` as the preference oracle:
> [!WARNING]
> Aligning with `PairRM` provides a lightweight example setup. For reproducing results from the paper or developing custom online alignment algorithms, we recommend using stronger reward models (or GPT-as-a-judge) as a preference oracle. This approach better approximates the ideal case of a human population. See the [examples](./examples/README.md#preference-oracles).
```shell
python -m oat.experiment.main \
--gpus 2 \
--collocate \
--dap-algo SimPO \
--beta 2 \
--preference-oracle pairrm \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--output_key pythia-1b-reference \
--sync-params-every 1 \
--rollout-batch-size-per-device 64 \
--pi-buffer-maxlen-per-device 64 \
--train-batch-size-per-device 8 \
--use-wb \
--wb-run-name 1b_pairrm_simpo_online
```
This example completes in **less than two hours on two A100-40G GPUs**!
To run an `offline SimPO` baseline for comparison, we disable weight synchronization from the learner to the actors by setting `sync-params-every` to any value larger than the total number of gradient steps (50000 // 128 = 390 for this run):
```diff
python -m oat.experiment.main \
--gpus 2 \
--collocate \
--dap-algo SimPO \
--beta 2 \
--preference-oracle pairrm \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--output_key pythia-1b-reference \
- --sync-params-every 1 \
+ --sync-params-every 9999 \
--rollout-batch-size-per-device 64 \
--pi-buffer-maxlen-per-device 64 \
--train-batch-size-per-device 8 \
--use-wb \
- --wb-run-name 1b_pairrm_simpo_online
+ --wb-run-name 1b_pairrm_simpo_offline
```
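The `9999` above works because the run never takes that many gradient steps. Assuming ~50,000 tl;dr prompts and an effective global batch of 128 (2 GPUs x 64 rollouts per device — our reading of the configuration, stated here as an assumption):

```python
# Total gradient steps for the offline run, under the assumptions above:
# ~50,000 prompts consumed in effective global batches of 128.
num_prompts = 50_000
global_batch = 2 * 64  # 2 GPUs x 64 rollouts per device
steps = num_prompts // global_batch
print(steps)  # 390
```

Any `sync-params-every` above 390 therefore means the actors never receive updated weights, turning the run into an offline baseline.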
Finally, we run `SEA SimPO` (with $\gamma=1$; see [here](https://arxiv.org/pdf/2411.01493#page=7.60) for the meaning of $\gamma$) to verify its sample-efficient alignment capability. This experiment uses 4 GPUs, with a reduced per-device training batch size to accommodate training an additional epistemic reward model. The per-device rollout batch size and buffer length are adjusted to keep the global batch size at 128. Additionally, 10 response candidates are generated for exploration using BAI Thompson sampling.
```diff
python -m oat.experiment.main \
- --gpus 2 \
+ --gpus 4 \
--dap-algo SimPO \
--beta 2 \
--preference-oracle pairrm \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--output_key pythia-1b-reference \
--sync-params-every 1 \
- --rollout-batch-size-per-device 64 \
- --pi-buffer-maxlen-per-device 64 \
- --train-batch-size-per-device 8 \
+ --rollout-batch-size-per-device 32 \
+ --pi-buffer-maxlen-per-device 32 \
+ --train-batch-size-per-device 1 \
+ --learn-rm \
+ --exp-method EnnBAITS \
+ --num_samples 10 \
--use-wb \
- --wb-run-name 1b_pairrm_simpo_online
+ --wb-run-name 1b_pairrm_simpo_sea
```
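Note how the per-device settings keep the effective global rollout batch fixed at 128 across the two setups (2 x 64 online, 4 x 32 for SEA); a quick sanity check of the arithmetic:

```python
# Both runs process the same global rollout batch of 128 prompts per iteration.
online_cfg = {"gpus": 2, "rollout_per_device": 64}
sea_cfg = {"gpus": 4, "rollout_per_device": 32}

def global_batch(cfg):
    return cfg["gpus"] * cfg["rollout_per_device"]

print(global_batch(online_cfg), global_batch(sea_cfg))  # 128 128
```

Keeping the global batch constant makes the online, offline, and SEA learning curves directly comparable at the same prompt budget.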
<p align="center">
<img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e23f40d33e8a2fa4220e8122c152b356084b8afb/example_result.png" width=55%/>
</p>
Check out this [tutorial](./examples/) for more examples covering:
* Various direct optimizers, including DPO, IPO, and SLiC.
* Different modes of preference oracles, such as remote reward models and GPT-as-a-judge.
* Additional LLM exploration algorithms, e.g., APL, XPO, and EE4LLM.
## Benchmarking
This benchmark compares oat with the online DPO implementation from [huggingface/trl](https://huggingface.co/docs/trl/main/en/online_dpo_trainer). Below, we outline the configurations used for oat and present the benchmarking results. Notably, oat 🌾 is up to **2.5x** more computationally efficient than trl 🤗.
<p align="center">
<img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e23f40d33e8a2fa4220e8122c152b356084b8afb/system_configs.png" width=97%/>
</p>
<p align="center">
<img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e23f40d33e8a2fa4220e8122c152b356084b8afb/bench_results.png" width=65% />
</p>
Please refer to [Appendix C of our paper](https://arxiv.org/pdf/2411.01493#page=17.64) for a detailed discussion of the benchmarking methods and results.
## Citation
If you find this work useful for your research, please consider citing:
```bibtex
@article{liu2024sea,
title={Sample-Efficient Alignment for LLMs},
author={Zichen Liu and Changyu Chen and Chao Du and Wee Sun Lee and Min Lin},
journal={arXiv preprint arXiv:2411.01493},
year={2024}
}
```
## License
`oat` is distributed under the terms of the [Apache2](https://www.apache.org/licenses/LICENSE-2.0) license.
## Acknowledgement
We thank the following awesome projects that have contributed to the development of oat:
* [vLLM](https://github.com/vllm-project/vllm)
* [DeepSpeed](https://github.com/microsoft/DeepSpeed)
* [Mosec](https://github.com/mosecorg/mosec)
* [launchpad](https://github.com/google-deepmind/launchpad)
* [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF)
## Disclaimer
This is not an official Sea Limited or Garena Online Private Limited product.