6 changes: 6 additions & 0 deletions openvalidators/config.py
@@ -265,6 +265,12 @@ def add_args(cls, parser):
help="Weight for the reciprocate reward model",
default=DefaultRewardFrameworkConfig.reciprocate_model_weight,
)
parser.add_argument(
"--reward.dpo_weight",
type=float,
help="Weight for the dpo reward model",
default=DefaultRewardFrameworkConfig.dpo_model_weight,
)
parser.add_argument(
"--reward.rlhf_weight",
type=float,
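For context on how the new dotted flag behaves, here is a minimal, self-contained argparse sketch mirroring the argument added above (a standalone stand-in for `add_args`, not the validator's actual entry point; the `0.2` default is taken from `DefaultRewardFrameworkConfig` further down in this PR):

```python
import argparse

# Standalone sketch of the new flag; mirrors the add_argument call above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--reward.dpo_weight",
    type=float,
    help="Weight for the dpo reward model",
    default=0.2,  # DefaultRewardFrameworkConfig.dpo_model_weight in this PR
)

args = parser.parse_args(["--reward.dpo_weight", "0.3"])
# argparse keeps the dot in the destination name, so read it back with getattr.
print(getattr(args, "reward.dpo_weight"))  # 0.3
```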
2 changes: 2 additions & 0 deletions openvalidators/event.py
@@ -41,6 +41,7 @@ class EventSchema:
nsfw_filter: Optional[List[float]] # Output vector of the nsfw filter
reciprocate_reward_model: Optional[List[float]] # Output vector of the reciprocate reward model
diversity_reward_model: Optional[List[float]] # Output vector of the diversity reward model
dpo_reward_model: Optional[List[float]] # Output vector of the dpo reward model
rlhf_reward_model: Optional[List[float]] # Output vector of the rlhf reward model
prompt_reward_model: Optional[List[float]] # Output vector of the prompt reward model
relevance_filter: Optional[List[float]] # Output vector of the relevance scoring reward model
@@ -60,6 +61,7 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> 'EventSchema':
'relevance_filter': event_dict.get(RewardModelType.relevance.value),
'reciprocate_reward_model': event_dict.get(RewardModelType.reciprocate.value),
'diversity_reward_model': event_dict.get(RewardModelType.diversity.value),
'dpo_reward_model': event_dict.get(RewardModelType.dpo.value),
'rlhf_reward_model': event_dict.get(RewardModelType.rlhf.value),
'prompt_reward_model': event_dict.get(RewardModelType.prompt.value),
}
5 changes: 5 additions & 0 deletions openvalidators/neuron.py
@@ -36,6 +36,7 @@
Blacklist,
TaskValidator,
NSFWRewardModel,
DirectPreferenceRewardModel,
OpenAssistantRewardModel,
ReciprocateRewardModel,
RelevanceRewardModel,
@@ -174,6 +175,7 @@ def __init__(self):
else:
self.reward_weights = torch.tensor(
[
self.config.reward.dpo_weight,
self.config.reward.rlhf_weight,
self.config.reward.reciprocate_weight,
self.config.reward.dahoas_weight,
@@ -192,6 +194,9 @@ def __init__(self):
raise Exception(message)

self.reward_functions = [
DirectPreferenceRewardModel(device=self.device)
if self.config.reward.dpo_weight > 0
else MockRewardModel(RewardModelType.dpo.value),
OpenAssistantRewardModel(device=self.device)
if self.config.reward.rlhf_weight > 0
else MockRewardModel(RewardModelType.rlhf.value),
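Note that `reward_weights` and `reward_functions` are positional pairs: the DPO entry is first in both lists, so index i of one must always correspond to index i of the other. The mixing code itself is outside this diff; the loop below is only an illustrative sketch of how such a weighted combination typically looks, not the repository's implementation:

```python
from typing import List

import torch


def combine_rewards(reward_weights: torch.Tensor,
                    reward_functions: List,
                    prompt: str,
                    completions: List[str]) -> torch.Tensor:
    # Illustrative only: weight i scales the output of reward function i, which is
    # why reward_weights and reward_functions must be declared in the same order.
    total = torch.zeros(len(completions))
    for weight, reward_fn in zip(reward_weights, reward_functions):
        per_completion = reward_fn.get_rewards(prompt, completions, reward_fn.name)
        total += float(weight) * per_completion.cpu()
    return total
```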
1 change: 1 addition & 0 deletions openvalidators/reward/__init__.py
@@ -1,6 +1,7 @@
from .blacklist import Blacklist
from .task_validator import TaskValidator
from .nsfw import NSFWRewardModel
from .dpo import DirectPreferenceRewardModel
from .open_assistant import OpenAssistantRewardModel
from .reciprocate import ReciprocateRewardModel
from .relevance import RelevanceRewardModel
4 changes: 3 additions & 1 deletion openvalidators/reward/config.py
@@ -18,6 +18,7 @@


class RewardModelType(Enum):
dpo = 'dpo_reward_model'
rlhf = 'rlhf_reward_model'
reciprocate = 'reciprocate_reward_model'
dahoas = 'dahoas_reward_model'
@@ -34,7 +35,8 @@ class DefaultRewardFrameworkConfig:
"""Reward framework default configuration.
Note: All the weights should add up to 1.0.
"""
rlhf_model_weight: float = 0.6
dpo_model_weight: float = 0.2
rlhf_model_weight: float = 0.4
reciprocate_model_weight: float = 0.4
dahoas_model_weight: float = 0
prompt_model_weight: float = 0
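The DPO weight is carved out of the old RLHF weight, so the defaults still respect the "add up to 1.0" note: 0.2 + 0.4 + 0.4 + 0 + 0 = 1.0. A quick sanity check (a sketch, assuming the five weights shown here are the full set of model weights on DefaultRewardFrameworkConfig):

```python
from openvalidators.reward.config import DefaultRewardFrameworkConfig as cfg

total = (cfg.dpo_model_weight + cfg.rlhf_model_weight + cfg.reciprocate_model_weight
         + cfg.dahoas_model_weight + cfg.prompt_model_weight)
assert abs(total - 1.0) < 1e-6  # 0.2 + 0.4 + 0.4 + 0 + 0
```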
93 changes: 93 additions & 0 deletions openvalidators/reward/dpo.py
@@ -0,0 +1,93 @@
# The MIT License (MIT)
# Copyright © 2021 Yuma Rao

# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software.

# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import torch
import bittensor as bt
from typing import List
from .config import RewardModelType
from .reward import BaseRewardModel
from transformers import AutoTokenizer, AutoModelForCausalLM


class DirectPreferenceRewardModel(BaseRewardModel):

reward_model_name: str = "cerebras/btlm-3b-8k-base"

@property
def name(self) -> str: return RewardModelType.dpo.value

def __init__(self, device: str):
super().__init__()
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained(DirectPreferenceRewardModel.reward_model_name)
self.model = AutoModelForCausalLM.from_pretrained(DirectPreferenceRewardModel.reward_model_name,
trust_remote_code=True,
torch_dtype=torch.float16).to(self.device)

def reward_single(self, prompt: str, completion: str, name: str) -> float:
r""" Calculates a direct preference optimization (DPO) style reward for a completion,
which is a reference model's average log-probability for completion tokens given a prompt.
Uses guidance from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py.
"""
with torch.no_grad():
# Tokenize the combined prompt + completion.
combined = self.tokenizer(prompt + completion, return_tensors="pt").input_ids[0].to(self.device) # [seq_len]
# Tokenize only the prompt, to help determine prompt token length.
prompt_part = self.tokenizer(prompt, return_tensors="pt").input_ids[0].to(self.device) # [prompt_len]

# Completion doesn't fit into model sequence, so return lowest reward.
if self.tokenizer.model_max_length <= len(prompt_part):
return -11. # exp(-11)=1.67e-5 < 2e-5=1/50257 (typical vocab size)

# Truncate combined to fit into model max sequence length.
if self.tokenizer.model_max_length < len(combined):
combined = combined[:self.tokenizer.model_max_length]

labels = combined.clone() # [seq_len]
# Ignore prompt part for calculating reward.
labels[:len(prompt_part)] = -100
# Label only each next token prediction ground-truth.
labels = labels[1:] # [seq_len-1]
loss_mask = (labels != -100) # [seq_len-1]
# Dummy token to allow for indexing, but loss will be ignored.
labels[labels == -100] = 0
# Reshape for gather operation.
labels = labels.unsqueeze(0).unsqueeze(2) # [batch_size=1, seq_len-1, 1]

# Forward pass to calculate logit predictions for each sequence position.
logits = self.model(combined.unsqueeze(0)).logits # [batch_size=1, seq_len, vocab_len]
# Predict only where labels are available.
logits = logits[:, :-1, :] # [batch_size=1, seq_len-1, vocab_len]

# Rescale via log(softmax(logits)).
logits = logits.log_softmax(-1)
# Calculate the model's log-probability for each actual completion token.
per_token_logps = torch.gather(logits, dim=2, index=labels).squeeze(2) # [batch_size=1, seq_len-1]
# Average log-probability over completion sequence.
reward = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) # [batch_size=1]
reward = reward[0].cpu().detach()

# NaNs/infs can arise through log(0)=-inf; return a suitably low reward instead.
if torch.isnan(reward) or torch.isinf(reward):
return -11. # exp(-11)=1.67e-5 < 2e-5=1/50257 (typical vocab size)
return reward.item()

def get_rewards(self, prompt: str, completions: List[str], name: str) -> torch.FloatTensor:
rewards = torch.tensor([self.reward_single(prompt, completion, name) for completion in completions],
dtype=torch.float32).to(self.device)
bt.logging.trace(f"DirectPreferenceRewardModel | rewards: {rewards.tolist()}")
return rewards
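In other words, reward_single scores a completion y = (y_1, ..., y_T) for prompt x as the reference model's mean per-token log-likelihood, (1/T) * sum_t log p(y_t | x, y_<t), with prompt positions excluded via the -100 labels. A minimal usage sketch (assumes a CUDA or CPU device string is available and that the cerebras/btlm-3b-8k-base weights can be downloaded; prompt and completions are illustrative):

```python
from openvalidators.reward.dpo import DirectPreferenceRewardModel

# Sketch: score two candidate completions for one prompt. The first call
# downloads ~3B-parameter weights, so "cuda" is strongly preferable to "cpu".
dpo = DirectPreferenceRewardModel(device="cuda")

prompt = "What is the capital of France?\n"
completions = ["The capital of France is Paris.", "Bananas are yellow."]

rewards = dpo.get_rewards(prompt, completions, dpo.name)
# Each entry is the mean log-probability of that completion's tokens under the
# reference model; values are negative, and less negative means more preferred.
print(rewards.tolist())
```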
3 changes: 3 additions & 0 deletions tests/test_event.py
@@ -42,6 +42,7 @@ def test_event_from_dict_all_forward_columns_match(self):
RewardModelType.nsfw.value: [1.0],
RewardModelType.reciprocate.value: [1.0],
RewardModelType.diversity.value: [1.0],
RewardModelType.dpo.value: [1.0],
RewardModelType.rlhf.value: [1.0],
RewardModelType.prompt.value: [1.0],
RewardModelType.relevance.value: [1.0],
@@ -100,6 +101,7 @@ def test_event_from_dict_forward_no_reward_logging(self):
assert event.nsfw_filter is None
assert event.reciprocate_reward_model is None
assert event.diversity_reward_model is None
assert event.dpo_reward_model is None
assert event.rlhf_reward_model is None
assert event.prompt_reward_model is None
assert event.relevance_filter is None
@@ -141,6 +143,7 @@ def test_event_from_dict_forward_reward_logging_mismatch(self):
assert event.nsfw_filter is None
assert event.reciprocate_reward_model is None
assert event.diversity_reward_model is None
assert event.dpo_reward_model is None
assert event.rlhf_reward_model is None
assert event.prompt_reward_model is None
assert event.relevance_filter is None