instruct-goose


- Name: instruct-goose
- Version: 0.0.7
- Home page: https://github.com/xrsrke/instructGOOSE
- Summary: Implementation of Reinforcement Learning from Human Feedback (RLHF)
- Upload time: 2023-04-03 03:02:52
- Author: xrsrke
- Requires Python: >=3.7
- License: Apache Software License 2.0
- Keywords: rlhf, reinforcement-learning, human-feedback, chatgpt, instructgpt

InstructGoose
================

<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

Paper: InstructGPT - [Training language models to follow instructions
with human feedback](https://arxiv.org/abs/2203.02155)

![image.png](index_files/figure-commonmark/d8305522-1-image.png)

## Install

Install from PyPI

``` sh
pip install instruct-goose
```

Install directly from the source code

``` sh
git clone https://github.com/xrsrke/instructGOOSE.git
cd instructGOOSE
pip install -e .
```

## Train the RL-based language model

``` python
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

import torch
from torch.utils.data import DataLoader, random_split
from torch import optim

from instruct_goose import Agent, RewardModel, RLHFTrainer, RLHFConfig, create_reference_model
```

**Step 1**: Load the dataset

``` python
dataset = load_dataset("imdb", split="train")
dataset, _ = random_split(dataset, lengths=[10, len(dataset) - 10]) # for demonstration purposes
train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
```

    Found cached dataset imdb (/Users/education/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)

**Step 2**: Load the pre-trained model and tokenizer

``` python
model_base = AutoModelForCausalLM.from_pretrained("gpt2") # for demonstration purposes
reward_model = RewardModel("gpt2")

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
eos_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
```
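
The tokenizer above is created with `padding_side="left"`. For a decoder-only model such as GPT-2, this keeps the last real token of every prompt in the final column of the batch, so generated tokens are appended directly after the prompt. The following is a minimal sketch of what left padding produces, written in plain Python; the helper name `left_pad` and the token ids are hypothetical and are not part of `instruct_goose` or `transformers`.

``` python
def left_pad(sequences, pad_id):
    """Pad variable-length token-id lists on the left to a common length."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        # Padding goes in front, so every row ends with a real token.
        input_ids.append([pad_id] * n_pad + list(seq))
        # 0 marks padding positions, 1 marks real tokens.
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


# Hypothetical token ids; 50256 is GPT-2's end-of-text id, reused as padding.
ids, mask = left_pad([[11, 12, 13], [21]], pad_id=50256)
print(ids)   # [[11, 12, 13], [50256, 50256, 21]]
print(mask)  # [[1, 1, 1], [0, 0, 1]]
```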

**Step 3**: Create the RL-based language model agent and the reference
model

``` python
model = Agent(model_base)
ref_model = create_reference_model(model)
```
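
In RLHF setups that follow the InstructGPT paper, the reference model is a frozen copy of the starting policy, and the trained policy is penalized for drifting away from it. The sketch below shows one common form of that penalty: a per-token KL-style term, with the reward model's score added on the final response token. Whether `RLHFTrainer` combines the terms in exactly this way is an assumption; the function name `kl_penalized_rewards` and the coefficient `beta` are hypothetical.

``` python
def kl_penalized_rewards(policy_logprobs, ref_logprobs, score, beta=0.1):
    """Combine a per-token KL-style penalty with a sequence-level score.

    policy_logprobs / ref_logprobs: log-probabilities that the trained policy
    and the frozen reference model assign to each generated token.
    score: scalar output of the reward model for the whole response.
    beta: penalty strength (illustrative default, not taken from the library).
    """
    if not policy_logprobs or len(policy_logprobs) != len(ref_logprobs):
        raise ValueError("log-probability lists must be non-empty and equal in length")
    # Penalize tokens that the policy makes more likely than the reference does.
    rewards = [-beta * (lp - lr) for lp, lr in zip(policy_logprobs, ref_logprobs)]
    # The reward model's score is credited to the last generated token.
    rewards[-1] += score
    return rewards


rewards = kl_penalized_rewards(
    policy_logprobs=[-1.0, -2.0, -0.5],
    ref_logprobs=[-1.0, -2.5, -1.5],
    score=1.0,
)
print(rewards)
```

When the policy and the reference model agree on a token, its penalty is zero, so only the reward model's score contributes.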

**Step 4**: Train it

``` python
max_new_tokens = 20
generation_kwargs = {
    "min_length":-1,
    "top_k": 0,  # 0 disables top-k filtering
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": max_new_tokens
}

config = RLHFConfig()
N_EPOCH = 1 # for demonstration purposes
trainer = RLHFTrainer(model, ref_model, config)
optimizer = optim.SGD(model.parameters(), lr=1e-3)
```

``` python
for epoch in range(N_EPOCH):
    for batch in train_dataloader:
        inputs = tokenizer(batch["text"], padding=True, truncation=True, return_tensors="pt")
        response_ids = model.generate(
            inputs["input_ids"], attention_mask=inputs["attention_mask"],
            **generation_kwargs
        )
        
        # extract the generated text
        response_ids = response_ids[:, -max_new_tokens:]
        response_attention_mask = torch.ones_like(response_ids)
        
        # evaluate from the reward model
        with torch.no_grad():
            text_input_ids = torch.stack([torch.concat([q, r]) for q, r in zip(inputs["input_ids"], response_ids)], dim=0)
            rewards = reward_model(text_input_ids)
        
        # calculate PPO loss
        loss = trainer.compute_loss(
            query_ids=inputs["input_ids"],
            query_attention_mask=inputs["attention_mask"],
            response_ids=response_ids,
            response_attention_mask=response_attention_mask,
            rewards=rewards
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss={loss}")
```

    loss=-824.6560668945312
    loss=0.030958056449890137
    loss=4.284017562866211
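
The loop above delegates the loss computation to `trainer.compute_loss`, whose internals are not shown here. As a rough illustration of the policy term in PPO, the following sketch implements the clipped surrogate objective in plain Python. It is not the library's implementation: the function name `ppo_clip_loss`, its arguments, and the default `clip_eps` are assumptions, and a full PPO loss usually also includes a value-function term.

``` python
import math


def ppo_clip_loss(new_logprobs, old_logprobs, advantages, clip_eps=0.2):
    """Mean clipped-surrogate PPO policy loss over a list of tokens."""
    losses = []
    for new_lp, old_lp, adv in zip(new_logprobs, old_logprobs, advantages):
        # Probability ratio between the updated policy and the sampling policy.
        ratio = math.exp(new_lp - old_lp)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Take the more pessimistic objective, then negate to obtain a loss.
        losses.append(-min(ratio * adv, clipped * adv))
    return sum(losses) / len(losses)


# The ratio is 2.0 here, so a positive advantage is clipped at 1 + clip_eps.
loss_clipped = ppo_clip_loss([math.log(2.0)], [0.0], [1.0])
# With a negative advantage the unclipped (worse) term is kept.
loss_unclipped = ppo_clip_loss([math.log(2.0)], [0.0], [-1.0])
print(loss_clipped, loss_unclipped)
```

Because the loss is a negated objective, its sign depends on the advantages, so a negative value of this term is not an error in itself.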

## TODO

- Add support for custom reward functions
- Add support for custom value functions
- Add support for non-transformer models
- Write config class

## Resources

I implemented this with the help of the following resources:

- Copied the
  [`load_yaml`](https://xrsrke.github.io/instructGOOSE/utils.html#load_yaml)
  function from https://github.com/Dahoas/reward-modeling
- How to build a dataset to train reward model:
  https://wandb.ai/carperai/summarize_RLHF/reports/Implementing-RLHF-Learning-to-Summarize-with-trlX--VmlldzozMzAwODM2
- How to add value head in PPO agent: https://github.com/lvwerra/trl
- How to calculate the loss of PPO agent:
  https://github.com/lvwerra/trl/blob/main/trl/trainer/ppo_trainer.py
- How to use PPO to train RLHF agent: https://github.com/voidful/TextRL
- How PPO works:
  https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py
- Copied the computation of `advantages` and `returns` from `TRL`:
  https://github.com/lvwerra/trl/blob/d2e8bcf8373726fb92d2110c500f7df6d0bd566d/trl/trainer/ppo_trainer.py#L686
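
The last item refers to computing `advantages` and `returns`, which PPO implementations commonly do with Generalized Advantage Estimation (GAE). The sketch below shows that recurrence in plain Python for a single response. It is an illustration, not the code copied from TRL: the function name, the list-based inputs, and the default `gamma` and `lam` values are assumptions.

``` python
def compute_advantages_and_returns(rewards, values, gamma=1.0, lam=0.95):
    """Generalized Advantage Estimation for one sequence of tokens.

    rewards[t] is the reward at step t, values[t] the value estimate at step t.
    The value after the final step is taken to be 0.
    """
    advantages = [0.0] * len(rewards)
    last_advantage = 0.0
    # Walk backwards so each step can reuse the advantage of the next step.
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        # One-step temporal-difference error.
        delta = rewards[t] + gamma * next_value - values[t]
        last_advantage = delta + gamma * lam * last_advantage
        advantages[t] = last_advantage
    # Returns are the regression targets for the value function.
    returns = [adv + val for adv, val in zip(advantages, values)]
    return advantages, returns


advantages, returns = compute_advantages_and_returns(
    rewards=[0.0, 0.0, 1.0],
    values=[0.5, 0.5, 0.5],
    gamma=1.0,
    lam=1.0,
)
print(advantages)  # [0.5, 0.5, 0.5]
print(returns)     # [1.0, 1.0, 1.0]
```

With `gamma=1.0` and `lam=1.0`, each return is the sum of the remaining rewards; lower `lam` values lean more heavily on the value estimates.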

            
