COPO: Count-based Online Preference Optimization

Code for the paper Online Preference Alignment for Language Models via Count-based Exploration. Paper link.

Prerequisites

To run the code in this project, first create a Python virtual environment, e.g. using Conda:

conda create -n rlhf python=3.10 && conda activate rlhf

You can then install the remaining package dependencies as follows:

python -m pip install .

You will also need Flash Attention 2 installed, which can be done by running:

python -m pip install flash-attn==2.3.6 --no-build-isolation
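
If you want to sanity-check the build, the package can be imported from Python (the expected version string assumes the pinned release above; this check is optional):

# Optional sanity check that Flash Attention 2 installed correctly.
import flash_attn
print(flash_attn.__version__)  # expected: 2.3.6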

Next, log into your Hugging Face account as follows:

huggingface-cli login

Finally, install Git LFS so that you can push models to the Hugging Face Hub:

sudo apt-get install git-lfs

Run the code

To train COPO on Meta-Llama-3-8B-Instruct, you first need to apply for access to the model.

Retrieve your Hugging Face and Weights & Biases (wandb) API keys from your accounts, then replace the xxxx placeholders for HUGGINGFACE_API_KEY and WANDB_API_KEY in recipes/launch_llama3_copo.sh. Also set BASE_DIR to your own path to the COPO root directory.

After the above preparation, train COPO on Meta-Llama-3-8B-Instruct:

cd COPO/recipes
sh launch_llama3_copo.sh

After training, the collected data is saved in dataset/ and the learned models in data/. Each iteration is saved in its own directory, and the final model is saved in data/COPO-Llama-3-8B-Instruct-iter-3/merge.
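
As a quick check, the merged model should load like a standard Hugging Face checkpoint. The sketch below assumes the output layout described above; the prompt and generation settings are illustrative, not taken from this repo:

# Minimal sketch: load the final merged model and generate one reply.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "data/COPO-Llama-3-8B-Instruct-iter-3/merge"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype="auto", device_map="auto"
)

messages = [{"role": "user", "content": "What is count-based exploration?"}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
outputs = model.generate(inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))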

File organization

The core implementation is in the following files:

  • scripts/online_feedback.py: generates responses for prompts drawn from preference datasets and uses PairRM to rank them, producing the online preference dataset.

  • scripts/count_trainer.py: trains the coin-flipping network (CFN) on the generated data, which is used to perform count-based exploration (see the sketch after this list).

  • scripts/copo_trainer.py: the main training loop of the COPO algorithm, which combines the online DPO and CFN objectives.

  • scripts/merge_model.py: merges the LoRA layers with the base network into a new model.
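
For readers unfamiliar with coin-flipping networks, below is a minimal, self-contained sketch of the idea behind scripts/count_trainer.py. It is not the repository's implementation: the architecture, sizes, and training step are assumptions made for illustration. A CFN regresses freshly sampled ±1 labels for each visited input, so the magnitude of its prediction decays roughly as 1/sqrt(visit count) and can serve as a count-based exploration bonus:

# Illustrative coin-flipping network (CFN) for pseudo-count estimation.
# Architecture and hyperparameters are assumptions, not the repo's code.
import torch
import torch.nn as nn

class CoinFlipNetwork(nn.Module):
    def __init__(self, input_dim: int, output_dim: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

def cfn_update(model: CoinFlipNetwork, optimizer: torch.optim.Optimizer,
               features: torch.Tensor) -> torch.Tensor:
    # One regression step on freshly sampled Rademacher (+/-1) targets.
    out = model(features)
    targets = torch.randint(0, 2, out.shape, device=out.device).to(out.dtype) * 2 - 1
    loss = torch.mean((out - targets) ** 2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()

@torch.no_grad()
def count_bonus(model: CoinFlipNetwork, features: torch.Tensor) -> torch.Tensor:
    # mean(f(x)^2) over output dims approximates 1/n(x), so its square
    # root behaves like 1/sqrt(n(x)): large for rarely seen inputs.
    pred = model(features)
    return torch.sqrt(torch.mean(pred ** 2, dim=-1))

In the COPO objective (see scripts/copo_trainer.py above), such a bonus can be combined with the online DPO loss, rewarding responses to rarely visited prompts so that the policy keeps exploring during online preference collection.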

Acknowledgement

This repo is built upon The Alignment Handbook. We thank the authors for their great work.
