Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions openvalidators/reward/diversity.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__( self, device: str ):
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained( DiversityRewardModel.diversity_model_path )
self.model = AutoModel.from_pretrained( DiversityRewardModel.diversity_model_path ).to(self.device)
self.reward_quantile = torch.tensor(0.1).to(self.device)
self.reward_bottom_k = 3
self.history_reward_bottom_k = 2
self.historic_embeddings = torch.tensor([]).to(self.device)
self.history_range = (500, 15500)
Expand Down Expand Up @@ -116,18 +116,22 @@ def regularise( rewards ):
similarity = pairwise_cosine_similarity( embeddings, self.historic_embeddings[self.history_range[0]:] )

# Reward to be at the bottom_k smallest of the 1 - similarity score.
rewards = torch.topk((1 - similarity), self.history_reward_bottom_k, largest = False)[0][:, -1]
rewards = torch.topk((1 - torch.abs(similarity)), self.history_reward_bottom_k, largest = False)[0][:, -1]

return regularise(rewards)

def get_batch_rewards( self, embeddings: torch.FloatTensor ) -> torch.FloatTensor:
    """Score each completion embedding by how dissimilar it is to the rest of the batch.

    Args:
        embeddings (torch.FloatTensor): embedding matrix, one row per completion
            (assumed 2-D — the ``[:, -1]`` indexing below requires it; TODO confirm upstream).

    Returns:
        torch.FloatTensor: per-completion diversity reward in (0, 1); higher means
        the completion is less similar to its nearest neighbours in the batch.
    """
    def regularise( rewards ):
        # sigmoid function that maps 0.07 -> 0.23; 0.1 -> 0.5; 0.2 -> 0.98
        return 1/(1 + torch.exp(-40 * rewards + 4))

    # Calculate the pairwise cosine similarity of every completion against the batch.
    similarity = pairwise_cosine_similarity( embeddings, embeddings )

    # Reward is the reward_bottom_k-th smallest (1 - |similarity|) score:
    # abs() penalises both positively and negatively aligned near-duplicates.
    rewards = torch.topk((1 - torch.abs(similarity)), self.reward_bottom_k, largest = False)[0][:, -1]

    return regularise(rewards)

def get_rewards( self, prompt: str, completions: List[str], name: str ) -> torch.FloatTensor:
# Check if completions are empty, return 0 if so
Expand All @@ -149,4 +153,7 @@ def get_rewards( self, prompt: str, completions: List[str], name: str ) -> torch
if historic_rewards != None:
return batch_rewards * historic_rewards
else:
return batch_rewards
return batch_rewards

def normalize_rewards( self, rewards: torch.FloatTensor ) -> torch.FloatTensor:
    """Return *rewards* unchanged — this model applies no extra normalization.

    The sigmoid inside the batch/historic scoring already bounds values;
    this hook presumably exists to satisfy a shared reward-model interface
    (NOTE(review): confirm against the base class, which is outside this view).
    """
    return rewards
2 changes: 1 addition & 1 deletion openvalidators/reward/relevance.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,4 @@ def reward( self, prompt: str, completion: str ) -> torch.FloatTensor:
# Calculate the pairwise cosine similarity.
similarity = pairwise_cosine_similarity( prompt_embed, embeddings )

return similarity
return torch.abs(similarity)