diff --git a/openvalidators/config.py b/openvalidators/config.py
index 4c00c21..2a5c69d 100644
--- a/openvalidators/config.py
+++ b/openvalidators/config.py
@@ -265,6 +265,12 @@ def add_args(cls, parser):
             help="Weight for the reciprocate reward model",
             default=DefaultRewardFrameworkConfig.reciprocate_model_weight,
         )
+        parser.add_argument(
+            "--reward.dpo_weight",
+            type=float,
+            help="Weight for the dpo reward model",
+            default=DefaultRewardFrameworkConfig.dpo_model_weight,
+        )
         parser.add_argument(
             "--reward.rlhf_weight",
             type=float,
diff --git a/openvalidators/event.py b/openvalidators/event.py
index 8aaec40..48ebabd 100644
--- a/openvalidators/event.py
+++ b/openvalidators/event.py
@@ -41,6 +41,7 @@ class EventSchema:
     nsfw_filter: Optional[List[float]]  # Output vector of the nsfw filter
     reciprocate_reward_model: Optional[List[float]]  # Output vector of the reciprocate reward model
     diversity_reward_model: Optional[List[float]]  # Output vector of the diversity reward model
+    dpo_reward_model: Optional[List[float]]  # Output vector of the dpo reward model
     rlhf_reward_model: Optional[List[float]]  # Output vector of the rlhf reward model
     prompt_reward_model: Optional[List[float]]  # Output vector of the prompt reward model
     relevance_filter: Optional[List[float]]  # Output vector of the relevance scoring reward model
@@ -60,6 +61,7 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> 'EventSchema':
             'relevance_filter': event_dict.get(RewardModelType.relevance.value),
             'reciprocate_reward_model': event_dict.get(RewardModelType.reciprocate.value),
             'diversity_reward_model': event_dict.get(RewardModelType.diversity.value),
+            'dpo_reward_model': event_dict.get(RewardModelType.dpo.value),
             'rlhf_reward_model': event_dict.get(RewardModelType.rlhf.value),
             'prompt_reward_model': event_dict.get(RewardModelType.prompt.value),
         }
diff --git a/openvalidators/neuron.py b/openvalidators/neuron.py
index bbbfd7e..f21e670 100644
--- a/openvalidators/neuron.py
+++ b/openvalidators/neuron.py
@@ -36,6 +36,7 @@
     Blacklist,
     TaskValidator,
     NSFWRewardModel,
+    DirectPreferenceRewardModel,
     OpenAssistantRewardModel,
     ReciprocateRewardModel,
     RelevanceRewardModel,
@@ -174,6 +175,7 @@ def __init__(self):
         else:
             self.reward_weights = torch.tensor(
                 [
+                    self.config.reward.dpo_weight,
                     self.config.reward.rlhf_weight,
                     self.config.reward.reciprocate_weight,
                     self.config.reward.dahoas_weight,
@@ -192,6 +194,9 @@ def __init__(self):
                 raise Exception(message)
 
             self.reward_functions = [
+                DirectPreferenceRewardModel(device=self.device)
+                if self.config.reward.dpo_weight > 0
+                else MockRewardModel(RewardModelType.dpo.value),
                 OpenAssistantRewardModel(device=self.device)
                 if self.config.reward.rlhf_weight > 0
                 else MockRewardModel(RewardModelType.rlhf.value),
diff --git a/openvalidators/reward/__init__.py b/openvalidators/reward/__init__.py
index c330866..51a20f9 100644
--- a/openvalidators/reward/__init__.py
+++ b/openvalidators/reward/__init__.py
@@ -1,6 +1,7 @@
 from .blacklist import Blacklist
 from .task_validator import TaskValidator
 from .nsfw import NSFWRewardModel
+from .dpo import DirectPreferenceRewardModel
 from .open_assistant import OpenAssistantRewardModel
 from .reciprocate import ReciprocateRewardModel
 from .relevance import RelevanceRewardModel
diff --git a/openvalidators/reward/config.py b/openvalidators/reward/config.py
index d3b8376..cdab7d0 100644
--- a/openvalidators/reward/config.py
+++ b/openvalidators/reward/config.py
@@ -18,6 +18,7 @@
 
 
 class RewardModelType(Enum):
+    dpo = 'dpo_reward_model'
     rlhf = 'rlhf_reward_model'
     reciprocate = 'reciprocate_reward_model'
     dahoas = 'dahoas_reward_model'
@@ -34,7 +35,8 @@ class DefaultRewardFrameworkConfig:
     """Reward framework default configuration.
     Note: All the weights should add up to 1.0.
     """
-    rlhf_model_weight: float = 0.6
+    dpo_model_weight: float = 0.2
+    rlhf_model_weight: float = 0.4
     reciprocate_model_weight: float = 0.4
     dahoas_model_weight: float = 0
     prompt_model_weight: float = 0
diff --git a/openvalidators/reward/dpo.py b/openvalidators/reward/dpo.py
new file mode 100644
index 0000000..3a5860b
--- /dev/null
+++ b/openvalidators/reward/dpo.py
@@ -0,0 +1,93 @@
+# The MIT License (MIT)
+# Copyright © 2021 Yuma Rao
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
+# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
+# the Software.
+
+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
+# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+import torch
+import bittensor as bt
+from typing import List
+from .config import RewardModelType
+from .reward import BaseRewardModel
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+class DirectPreferenceRewardModel(BaseRewardModel):
+
+    reward_model_name: str = "cerebras/btlm-3b-8k-base"
+
+    @property
+    def name(self) -> str: return RewardModelType.dpo.value
+
+    def __init__(self, device: str):
+        super().__init__()
+        self.device = device
+        self.tokenizer = AutoTokenizer.from_pretrained(DirectPreferenceRewardModel.reward_model_name)
+        self.model = AutoModelForCausalLM.from_pretrained(DirectPreferenceRewardModel.reward_model_name,
+                                                          trust_remote_code=True,
+                                                          torch_dtype=torch.float16).to(self.device)
+
+    def reward_single(self, prompt: str, completion: str, name: str) -> float:
+        r""" Calculates a direct preference optimization (DPO) style reward for a completion,
+        which is a reference model's average log-probability for completion tokens given a prompt.
+        Uses guidance from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py.
+        """
+        with torch.no_grad():
+            # Tokenize the combined prompt + completion.
+            combined = self.tokenizer(prompt + completion, return_tensors="pt").input_ids[0].to(self.device)  # [seq_len]
+            # Tokenize only the prompt, to help determine prompt token length.
+            prompt_part = self.tokenizer(prompt, return_tensors="pt").input_ids[0].to(self.device)  # [prompt_len]
+
+            # Completion doesn't fit into model sequence, so return lowest reward.
+            if self.tokenizer.model_max_length <= len(prompt_part):
+                return -11.  # exp(-11)=1.67e-5 < 2e-5=1/50257 (typical vocab size)
+
+            # Truncate combined to fit into model max sequence length.
+            if self.tokenizer.model_max_length < len(combined):
+                combined = combined[:self.tokenizer.model_max_length]
+
+            labels = combined.clone()  # [seq_len]
+            # Ignore prompt part for calculating reward.
+            labels[:len(prompt_part)] = -100
+            # Label only each next token prediction ground-truth.
+            labels = labels[1:]  # [seq_len-1]
+            loss_mask = (labels != -100)  # [seq_len-1]
+            # Dummy token to allow for indexing, but loss will be ignored.
+            labels[labels == -100] = 0
+            # Reshape for gather operation.
+            labels = labels.unsqueeze(0).unsqueeze(2)  # [batch_size=1, seq_len-1, 1]
+
+            # Forward pass to calculate logit predictions for each sequence position.
+            logits = self.model(combined.unsqueeze(0)).logits  # [batch_size=1, seq_len, vocab_len]
+            # Predict only where labels are available.
+            logits = logits[:, :-1, :]  # [batch_size=1, seq_len-1, vocab_len]
+
+            # Rescale via log(softmax(logits)).
+            logits = logits.log_softmax(-1)
+            # Calculate the model's log-probability for each actual completion token.
+            per_token_logps = torch.gather(logits, dim=2, index=labels).squeeze(2)  # [batch_size=1, seq_len-1]
+            # Average log-probability over completion sequence.
+            reward = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)  # [batch_size=1]
+            reward = reward[0].cpu().detach()
+
+            # NaNs can possibly arise through log(0)=-inf, so replace with a suitably small reward.
+            if torch.isnan(reward) or torch.isinf(reward):
+                return -11.  # exp(-11)=1.67e-5 < 2e-5=1/50257 (typical vocab size)
+            return reward.item()
+
+    def get_rewards(self, prompt: str, completions: List[str], name: str) -> torch.FloatTensor:
+        rewards = torch.tensor([self.reward_single(prompt, completion, name) for completion in completions],
+                               dtype=torch.float32).to(self.device)
+        bt.logging.trace(f"DirectPreferenceRewardModel | rewards: {rewards.tolist()}")
+        return rewards
diff --git a/tests/test_event.py b/tests/test_event.py
index 7fb9f2b..ff366f5 100644
--- a/tests/test_event.py
+++ b/tests/test_event.py
@@ -42,6 +42,7 @@ def test_event_from_dict_all_forward_columns_match(self):
             RewardModelType.nsfw.value: [1.0],
             RewardModelType.reciprocate.value: [1.0],
             RewardModelType.diversity.value: [1.0],
+            RewardModelType.dpo.value: [1.0],
             RewardModelType.rlhf.value: [1.0],
             RewardModelType.prompt.value: [1.0],
             RewardModelType.relevance.value: [1.0],
@@ -100,6 +101,7 @@ def test_event_from_dict_forward_no_reward_logging(self):
         assert event.nsfw_filter is None
         assert event.reciprocate_reward_model is None
         assert event.diversity_reward_model is None
+        assert event.dpo_reward_model is None
         assert event.rlhf_reward_model is None
         assert event.prompt_reward_model is None
         assert event.relevance_filter is None
@@ -141,6 +143,7 @@ def test_event_from_dict_forward_reward_logging_mismatch(self):
         assert event.nsfw_filter is None
         assert event.reciprocate_reward_model is None
         assert event.diversity_reward_model is None
+        assert event.dpo_reward_model is None
         assert event.rlhf_reward_model is None
         assert event.prompt_reward_model is None
         assert event.relevance_filter is None
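Note on the scoring logic above: `reward_single` is just the reference model's mean log-probability over completion tokens, computed with a shift-by-one gather and a prompt mask, and the new default weights (`dpo` 0.2, `rlhf` 0.4, `reciprocate` 0.4) still sum to 1.0 as the config requires. For reviewers who want to sanity-check the math in isolation, here is a minimal standalone sketch of the same computation; the model name `sshleifer/tiny-gpt2` and the prompt/completion strings are illustrative placeholders, not part of this change:

```python
# Standalone sketch of the DPO-style reward above (works with any HF causal LM;
# "sshleifer/tiny-gpt2" and the strings below are hypothetical placeholders).
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

prompt = "Q: What is the capital of France?\nA:"
completion = " Paris."

with torch.no_grad():
    combined = tokenizer(prompt + completion, return_tensors="pt").input_ids  # [1, seq_len]
    prompt_len = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]

    # Logits at position t predict the token at position t+1, so drop the last
    # logit and the first token to align predictions with targets.
    logps = model(combined).logits[:, :-1, :].log_softmax(-1)  # [1, seq_len-1, vocab]
    targets = combined[:, 1:]                                  # [1, seq_len-1]

    # Zero out prompt positions so only completion tokens contribute, mirroring
    # the labels[...] = -100 masking in reward_single.
    mask = torch.zeros_like(targets, dtype=torch.float)
    mask[:, prompt_len - 1:] = 1.0

    # Gather each target token's log-probability, then average over the completion.
    per_token_logps = logps.gather(2, targets.unsqueeze(2)).squeeze(2)  # [1, seq_len-1]
    reward = (per_token_logps * mask).sum(-1) / mask.sum(-1)
    print(f"DPO-style reward (mean completion log-prob): {reward.item():.4f}")
```

The -11. floor used for degenerate cases (prompt filling the whole context window, NaN/inf averages) is chosen so that exp(-11) ≈ 1.67e-5 sits just below 1/50257, i.e. below the per-token probability of a uniform guess over a typical GPT-2-sized vocabulary.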