Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 53 additions & 6 deletions openvalidators/reward/diversity.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def __init__( self, device: str ):
self.tokenizer = AutoTokenizer.from_pretrained( DiversityRewardModel.diversity_model_path )
self.model = AutoModel.from_pretrained( DiversityRewardModel.diversity_model_path ).to(self.device)
self.reward_quantile = torch.tensor(0.1).to(self.device)
self.history_reward_bottom_k = 5
self.historic_embeddings = torch.tensor([]).to(self.device)
self.history_range = (500, 15500)

def get_embeddings( self, sentences: List[str] ) -> "torch.FloatTensor":
"""Runs a forward pass through the model.
Expand Down Expand Up @@ -86,20 +89,64 @@ def get_embeddings( self, sentences: List[str] ) -> "torch.FloatTensor":
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
return sentence_embeddings

def get_rewards( self, prompt: str, completions: List[str], name: str ) -> torch.FloatTensor:
def update_historic_embeddings( self, embeddings: torch.FloatTensor ):
    """Append a batch's embeddings to the rolling history buffer.

    Consecutive duplicate embeddings in the batch are collapsed to a single
    entry before being stored, and the buffer is truncated to at most the
    ``self.history_range[1]`` most-recent rows.

    Args:
        embeddings (torch.FloatTensor): 2-D tensor of shape (batch, dim).
    """
    def _drop_consecutive_duplicates(embs):
        # NOTE: only removes *adjacent* duplicates; identical rows separated
        # by a different row are all kept (matches original behavior).
        kept = [embs[0]]
        last = embs[0]
        for emb in embs:
            if not torch.all(torch.eq(emb, last)):
                kept.append(emb)
                last = emb
        return torch.stack(kept)

    # BUG FIX: an empty batch would crash on embeddings[0]; nothing to record.
    if embeddings.shape[0] == 0:
        return

    unique_embeddings = _drop_consecutive_duplicates(embeddings)
    combined = torch.cat([self.historic_embeddings, unique_embeddings])
    # Keep only the most recent `history_range[1]` rows.
    self.historic_embeddings = combined[-self.history_range[1]:, :]

def get_historic_rewards( self, embeddings: torch.FloatTensor ) -> torch.FloatTensor:
    """Reward completions for being dissimilar to historically seen ones.

    Args:
        embeddings (torch.FloatTensor): 2-D tensor of shape (batch, dim).

    Returns:
        torch.FloatTensor: One reward per row of ``embeddings``, or ``None``
        while the history buffer is still too small to score against.
    """
    def _regularise(rewards):
        # Steep sigmoid centred at 0.05: dissimilarities below ~0.05 are
        # squashed towards 0, those above it saturate towards 1.
        return 1 / (1 + torch.exp(-1000 * rewards + 50))

    # The first `history_range[0]` rows are a warm-up region that is never
    # scored against. BUG FIX: we also need at least `history_reward_bottom_k`
    # scorable rows past the warm-up region, otherwise torch.topk(k=5) below
    # raises; the original only checked against history_range[0].
    scorable = self.historic_embeddings.shape[0] - self.history_range[0]
    if scorable < self.history_reward_bottom_k:
        return None

    # Pairwise cosine similarity against the scorable slice of the history.
    similarity = pairwise_cosine_similarity( embeddings, self.historic_embeddings[self.history_range[0]:] )

    # Reward is the k-th smallest dissimilarity (1 - similarity), i.e. the
    # distance to the k-th closest historic neighbour.
    rewards = torch.topk((1 - similarity), self.history_reward_bottom_k, largest = False)[0][:, -1]

    return _regularise(rewards)

def get_batch_rewards( self, embeddings: torch.FloatTensor ) -> torch.FloatTensor:
    """Reward each completion by its dissimilarity to the rest of the batch.

    For every embedding, compute 1 - cosine similarity against the whole
    batch (itself included) and take the ``self.reward_quantile`` quantile
    of those dissimilarities as that row's reward.

    Args:
        embeddings (torch.FloatTensor): 2-D tensor of shape (batch, dim).

    Returns:
        torch.FloatTensor: One reward per row of ``embeddings``.
    """
    # Cosine similarity of every completion against every other completion.
    pairwise_sim = pairwise_cosine_similarity( embeddings, embeddings )
    dissimilarity = 1 - pairwise_sim
    # One reward per row: the configured quantile of its dissimilarities.
    return dissimilarity.quantile(self.reward_quantile, dim = 1)

def get_rewards( self, prompt: str, completions: List[str], name: str ) -> torch.FloatTensor:
    """Score completions by diversity within the batch and against history.

    Args:
        prompt (str): Prompt the completions answer (unused here; kept for
            the shared reward-model interface).
        completions (List[str]): Completion texts to score.
        name (str): Name of the reward event (unused here).

    Returns:
        torch.FloatTensor: One reward per completion; an empty tensor when
        ``completions`` is empty.
    """
    # Check if completions are empty, return an empty reward tensor if so.
    if len(completions) == 0:
        return torch.tensor([]).to(self.device)

    # Embed all completions once and reuse for both reward components.
    embeddings = self.get_embeddings( completions )

    # Diversity within the current batch.
    batch_rewards = self.get_batch_rewards(embeddings)

    # Diversity against previously seen completions (None while the
    # history buffer is still warming up).
    historic_rewards = self.get_historic_rewards(embeddings)

    self.update_historic_embeddings(embeddings)

    # BUG FIX: the original tested `historic_rewards != None`, which invokes
    # tensor elementwise comparison; identity must be tested with `is`.
    if historic_rewards is not None:
        return batch_rewards * historic_rewards
    return batch_rewards