22 changes: 21 additions & 1 deletion openvalidators/event.py
@@ -47,15 +47,25 @@ class EventSchema:
     relevance_filter: Optional[List[float]] # Output vector of the relevance scoring reward model
     task_validator_filter: Optional[List[float]]
 
+    dahoas_reward_model_normalized: Optional[List[float]] # Output vector of the dahoas reward model
+    nsfw_filter_normalized: Optional[List[float]] # Output vector of the nsfw filter
+    reciprocate_reward_model_normalized: Optional[List[float]] # Output vector of the reciprocate reward model
+    diversity_reward_model_normalized: Optional[List[float]] # Output vector of the diversity reward model
+    dpo_reward_model_normalized: Optional[List[float]] # Output vector of the dpo reward model
+    rlhf_reward_model_normalized: Optional[List[float]] # Output vector of the rlhf reward model
+    prompt_reward_model_normalized: Optional[List[float]] # Output vector of the prompt reward model
+    relevance_filter_normalized: Optional[List[float]] # Output vector of the relevance scoring reward model
+    task_validator_filter_normalized: Optional[List[float]]
+
     # Weights data
     set_weights: Optional[List[List[float]]]
 
     @staticmethod
     def from_dict(event_dict: dict, disable_log_rewards: bool) -> 'EventSchema':
         """Converts a dictionary to an EventSchema object."""
         rewards = {
-            'dahoas_reward_model': event_dict.get(RewardModelType.dahoas.value),
             'blacklist_filter': event_dict.get(RewardModelType.blacklist.value),
+            'dahoas_reward_model': event_dict.get(RewardModelType.dahoas.value),
             'task_validator_filter': event_dict.get(RewardModelType.task_validator.value),
             'nsfw_filter': event_dict.get(RewardModelType.nsfw.value),
             'relevance_filter': event_dict.get(RewardModelType.relevance.value),
@@ -64,6 +74,16 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> 'EventSchema':
             'dpo_reward_model': event_dict.get(RewardModelType.dpo.value),
             'rlhf_reward_model': event_dict.get(RewardModelType.rlhf.value),
             'prompt_reward_model': event_dict.get(RewardModelType.prompt.value),
+
+            'dahoas_reward_model_normalized': event_dict.get(RewardModelType.dahoas.value + '_normalized'),
+            'task_validator_filter_normalized': event_dict.get(RewardModelType.task_validator.value + '_normalized'),
+            'nsfw_filter_normalized': event_dict.get(RewardModelType.nsfw.value + '_normalized'),
+            'relevance_filter_normalized': event_dict.get(RewardModelType.relevance.value + '_normalized'),
+            'reciprocate_reward_model_normalized': event_dict.get(RewardModelType.reciprocate.value + '_normalized'),
+            'diversity_reward_model_normalized': event_dict.get(RewardModelType.diversity.value + '_normalized'),
+            'dpo_reward_model_normalized': event_dict.get(RewardModelType.dpo.value + '_normalized'),
+            'rlhf_reward_model_normalized': event_dict.get(RewardModelType.rlhf.value + '_normalized'),
+            'prompt_reward_model_normalized': event_dict.get(RewardModelType.prompt.value + '_normalized'),
         }
 
         # Logs warning that expected data was not set properly
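Reviewer note: the net effect of the schema change is that every reward column except blacklist_filter gains a '_normalized' twin, stored under the enum value plus a '_normalized' suffix. A minimal sketch of that key convention, using a stand-in RewardModelType whose member values are assumptions inferred from the column names in this diff:

    from enum import Enum
    from typing import List, Optional

    # Stand-in for openvalidators' RewardModelType; member values are
    # assumed from the column names appearing in this diff.
    class RewardModelType(Enum):
        dahoas = 'dahoas_reward_model'
        nsfw = 'nsfw_filter'

    # A logged event carries both the raw and the normalized vectors.
    event_dict = {
        'dahoas_reward_model': [0.1, 0.7],
        'dahoas_reward_model_normalized': [0.25, 0.75],
    }

    # from_dict resolves both columns with the same .get() pattern, so a
    # column missing from the dict simply becomes None on the schema.
    raw: Optional[List[float]] = event_dict.get(RewardModelType.dahoas.value)
    normalized: Optional[List[float]] = event_dict.get(RewardModelType.dahoas.value + '_normalized')
    print(raw, normalized)  # [0.1, 0.7] [0.25, 0.75]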
14 changes: 8 additions & 6 deletions openvalidators/forward.py
@@ -87,18 +87,20 @@ async def run_step(self, prompt: str, k: int, timeout: float, name: str, exclude
         # Compute the rewards for the responses given the prompt.
         rewards: torch.FloatTensor = torch.zeros(len(responses), dtype=torch.float32).to(self.device)
         for weight_i, reward_fn_i in zip(self.reward_weights, self.reward_functions):
-            reward_i = reward_fn_i.apply(prompt, responses, name).to(self.device)
-            rewards += weight_i * reward_i
+            reward_i, reward_i_normalized = reward_fn_i.apply(prompt, responses, name)
+            rewards += weight_i * reward_i_normalized.to(self.device)
             if not self.config.neuron.disable_log_rewards:
                 event[reward_fn_i.name] = reward_i.tolist()
-            bt.logging.trace(str(reward_fn_i.name), reward_i.tolist())
+                event[reward_fn_i.name + '_normalized'] = reward_i_normalized.tolist()
+            bt.logging.trace(str(reward_fn_i.name), reward_i_normalized.tolist())
 
         for masking_fn_i in self.masking_functions:
-            mask_i = masking_fn_i.apply(base_prompt, responses, name).to(self.device)
-            rewards *= mask_i # includes diversity
+            mask_i, mask_i_normalized = masking_fn_i.apply(base_prompt, responses, name)
+            rewards *= mask_i_normalized.to(self.device) # includes diversity
             if not self.config.neuron.disable_log_rewards:
                 event[masking_fn_i.name] = mask_i.tolist()
-            bt.logging.trace(str(masking_fn_i.name), mask_i.tolist())
+                event[masking_fn_i.name + '_normalized'] = mask_i_normalized.tolist()
+            bt.logging.trace(str(masking_fn_i.name), mask_i_normalized.tolist())
 
         # Train the gating model based on the predicted scores and the actual rewards.
         gating_scores: torch.FloatTensor = self.gating_model(prompt).to(self.device)
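The behavioural core of this change: reward functions now return a (raw, normalized) pair, only the normalized tensor enters the weighted sum, and both tensors are logged. A self-contained sketch of the new accumulation loop with a stand-in reward function; softmax normalization is an assumption here, taken from the 'Softmax rewards across samples' comment in reward.py:

    import torch

    # Stand-in reward function mirroring the new apply() contract:
    # it returns raw scores plus a normalized copy, not a single tensor.
    def reward_fn(prompt: str, responses: list) -> tuple:
        raw = torch.rand(len(responses))
        normalized = torch.softmax(raw, dim=0)  # assumed normalization scheme
        return raw, normalized

    reward_functions = [reward_fn, reward_fn]
    reward_weights = torch.tensor([0.7, 0.3])
    responses = ['a', 'b', 'c']

    event = {}
    rewards = torch.zeros(len(responses))
    for i, (weight_i, reward_fn_i) in enumerate(zip(reward_weights, reward_functions)):
        reward_i, reward_i_normalized = reward_fn_i('prompt', responses)
        rewards += weight_i * reward_i_normalized   # only the normalized vector moves the score
        event[f'reward_fn_{i}'] = reward_i.tolist() # raw vector, kept for logging
        event[f'reward_fn_{i}_normalized'] = reward_i_normalized.tolist()
    print(rewards)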
14 changes: 8 additions & 6 deletions openvalidators/reward/reward.py
@@ -98,17 +98,19 @@ def apply( self, prompt: str, responses: List[ bt.DendriteCall ], name: str) ->
         successful_rewards = self.get_rewards( prompt, successful_completions, name )
 
         # Softmax rewards across samples.
-        successful_rewards = self.normalize_rewards( successful_rewards )
+        successful_rewards_normalized = self.normalize_rewards( successful_rewards )
 
         # Init zero rewards for all calls.
-        filled_rewards = torch.zeros( len( responses ), dtype=torch.float32)
+        filled_rewards = torch.ones( len( responses ), dtype=torch.float32) * torch.nan
+        filled_rewards_normalized = torch.zeros( len( responses ), dtype=torch.float32)
 
         # Fill reward tensor.
-        for idx, reward in zip(successful_completions_indices, successful_rewards):
+        for idx, reward, reward_normalized in zip(successful_completions_indices, successful_rewards, successful_rewards_normalized):
             filled_rewards[idx] = reward
+            filled_rewards_normalized[idx] = reward_normalized
 
         # Return the filled rewards.
-        return filled_rewards
+        return filled_rewards, filled_rewards_normalized


class MockRewardModel( BaseRewardModel ):
@@ -121,7 +123,7 @@ def __init__(self, mock_name: str = 'MockReward'):
         self.mock_name = mock_name
 
     def apply( self, prompt: str, completion: List[str], name: str ) -> torch.FloatTensor:
-        return torch.tensor( [0 for _ in completion], dtype=torch.float32 )
-
+        mock_reward = torch.tensor( [0 for _ in completion], dtype=torch.float32 )
+        return mock_reward, mock_reward
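Reviewer note on the fill convention above: raw rewards are now initialised to NaN instead of zero, so a failed dendrite call is distinguishable in the logs from a call that genuinely scored 0.0, while the normalized tensor stays zero-filled so the weighted sum in forward.py remains NaN-free. A small sketch of the effect, with illustrative values:

    import torch

    num_responses = 5
    successful_indices = [0, 2, 4]                      # calls that returned a completion
    successful_rewards = torch.tensor([0.3, 1.2, 0.9])
    successful_rewards_normalized = torch.softmax(successful_rewards, dim=0)

    filled_rewards = torch.ones(num_responses) * torch.nan   # NaN marks "no score"
    filled_rewards_normalized = torch.zeros(num_responses)   # zeros stay summable

    for idx, r, rn in zip(successful_indices, successful_rewards, successful_rewards_normalized):
        filled_rewards[idx] = r
        filled_rewards_normalized[idx] = rn

    print(filled_rewards)             # tensor([0.3000, nan, 1.2000, nan, 0.9000])
    print(filled_rewards_normalized)  # failed calls contribute exactly 0 downstream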


39 changes: 37 additions & 2 deletions tests/test_event.py
@@ -46,7 +46,18 @@ def test_event_from_dict_all_forward_columns_match(self):
             RewardModelType.rlhf.value: [1.0],
             RewardModelType.prompt.value: [1.0],
             RewardModelType.relevance.value: [1.0],
-            RewardModelType.task_validator.value: [1.0]
+            RewardModelType.task_validator.value: [1.0],
+
+            RewardModelType.dahoas.value + '_normalized': [1.0],
+            RewardModelType.blacklist.value + '_normalized': [1.0],
+            RewardModelType.nsfw.value + '_normalized': [1.0],
+            RewardModelType.reciprocate.value + '_normalized': [1.0],
+            RewardModelType.diversity.value + '_normalized': [1.0],
+            RewardModelType.dpo.value + '_normalized': [1.0],
+            RewardModelType.rlhf.value + '_normalized': [1.0],
+            RewardModelType.prompt.value + '_normalized': [1.0],
+            RewardModelType.relevance.value + '_normalized': [1.0],
+            RewardModelType.task_validator.value + '_normalized': [1.0]
         }
 
         # Act
@@ -107,6 +118,16 @@ def test_event_from_dict_forward_no_reward_logging(self):
         assert event.relevance_filter is None
         assert event.task_validator_filter is None
 
+        assert event.dahoas_reward_model_normalized is None
+        assert event.nsfw_filter_normalized is None
+        assert event.reciprocate_reward_model_normalized is None
+        assert event.diversity_reward_model_normalized is None
+        assert event.dpo_reward_model_normalized is None
+        assert event.rlhf_reward_model_normalized is None
+        assert event.prompt_reward_model_normalized is None
+        assert event.relevance_filter_normalized is None
+        assert event.task_validator_filter_normalized is None
+
     def test_event_from_dict_forward_reward_logging_mismatch(self):
         """Test that all default columns logged on the forward pass are correctly converted
         and that reward columns that should have been logged are reported as warnings"""
@@ -124,7 +145,12 @@ def test_event_from_dict_forward_reward_logging_mismatch(self):
             'rewards': [1.0],
         }
 
-        not_logged_columns = [field.value for field in RewardModelType]
+        not_logged_columns = []
+        for field in RewardModelType:
+            not_logged_columns.append(field.value)
+            if field.value != 'blacklist_filter':
+                not_logged_columns.append(field.value + '_normalized')
+
 
         # Act
         with patch('bittensor.logging.warning') as mock_warning:
@@ -149,3 +175,12 @@ def test_event_from_dict_forward_reward_logging_mismatch(self):
         assert event.relevance_filter is None
         assert event.task_validator_filter is None
 
+        assert event.dahoas_reward_model_normalized is None
+        assert event.nsfw_filter_normalized is None
+        assert event.reciprocate_reward_model_normalized is None
+        assert event.diversity_reward_model_normalized is None
+        assert event.dpo_reward_model_normalized is None
+        assert event.rlhf_reward_model_normalized is None
+        assert event.prompt_reward_model_normalized is None
+        assert event.relevance_filter_normalized is None
+        assert event.task_validator_filter_normalized is None
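Reviewer note: the mismatch test encodes the one asymmetry in this PR, namely that blacklist_filter is the only reward column that does not gain a '_normalized' twin (compare the schema changes in event.py), so the expected-warning list skips its normalized name. A condensed sketch of the list being built, with literal column names standing in for the enum values:

    # Literal column names assumed from this diff; the real test iterates
    # over RewardModelType instead of a hand-written list.
    reward_columns = [
        'dahoas_reward_model', 'blacklist_filter', 'nsfw_filter',
        'reciprocate_reward_model', 'diversity_reward_model', 'dpo_reward_model',
        'rlhf_reward_model', 'prompt_reward_model', 'relevance_filter',
        'task_validator_filter',
    ]

    not_logged_columns = []
    for column in reward_columns:
        not_logged_columns.append(column)
        if column != 'blacklist_filter':    # no normalized twin for the blacklist
            not_logged_columns.append(column + '_normalized')

    assert len(not_logged_columns) == 19    # 10 raw columns + 9 normalized twins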