diff --git a/openvalidators/reward/diversity.py b/openvalidators/reward/diversity.py
index 185c53b..8b24ab1 100644
--- a/openvalidators/reward/diversity.py
+++ b/openvalidators/reward/diversity.py
@@ -56,6 +56,9 @@ def __init__( self, device: str ):
         self.tokenizer = AutoTokenizer.from_pretrained( DiversityRewardModel.diversity_model_path )
         self.model = AutoModel.from_pretrained( DiversityRewardModel.diversity_model_path ).to(self.device)
         self.reward_quantile = torch.tensor(0.1).to(self.device)
+        self.history_reward_bottom_k = 5
+        self.historic_embeddings = torch.tensor([]).to(self.device)
+        self.history_range = (500, 15500)
 
     def get_embeddings( self, sentences: List[str] ) -> "torch.FloatTensor":
         """Runs a forward pass through the model.
@@ -86,8 +89,50 @@ def get_embeddings( self, sentences: List[str] ) -> "torch.FloatTensor":
         sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
         return sentence_embeddings
 
-    def get_rewards( self, prompt: str, completions: List[str], name: str ) -> torch.FloatTensor:
+    def update_historic_embeddings( self, embeddings: torch.FloatTensor ):
+        """Append the batch embeddings (consecutive duplicates dropped) to the rolling history buffer."""
+        def unique(embeddings):
+            unique_embeddings = [embeddings[0]]
+            last_emb = embeddings[0]
+            for emb in embeddings:
+                if not torch.all(torch.eq(emb, last_emb)):
+                    unique_embeddings.append(emb)
+                    last_emb = emb
+            return torch.stack(unique_embeddings)
+
+        embeddings_unique = unique(embeddings)
+        historic_embeddings = torch.cat([self.historic_embeddings, embeddings_unique])
+        self.historic_embeddings = historic_embeddings[-self.history_range[1]:, :]
+
+    def get_historic_rewards( self, embeddings: torch.FloatTensor ) -> torch.FloatTensor:
+        """Reward dissimilarity to the embedding history; returns None while the history is warming up."""
+        def regularise( rewards ):
+            # sigmoid function that cutoff at 0.05 approximately
+            return 1/(1 + torch.exp(-1000 * rewards + 50))
+
+        # Return None if history size is too small
+        if self.historic_embeddings.shape[0] < self.history_range[0]:
+            return None
+
+        # Calculate the pairwise cosine similarity.
+        similarity = pairwise_cosine_similarity( embeddings, self.historic_embeddings[self.history_range[0]:] )
+
+        # Reward is the k-th smallest (1 - similarity) score, k = history_reward_bottom_k.
+        rewards = torch.topk((1 - similarity), self.history_reward_bottom_k, largest = False)[0][:, -1]
+
+        return regularise(rewards)
+
+    def get_batch_rewards( self, embeddings: torch.FloatTensor ) -> torch.FloatTensor:
+        """Reward diversity within the current batch of embeddings."""
+        # Calculate the pairwise cosine similarity.
+        similarity = pairwise_cosine_similarity( embeddings, embeddings )
+        # Reward to be at the 10% quantile of the 1 - similarity score.
+        rewards = (1 - similarity).quantile(self.reward_quantile, dim = 1 )
+
+        return rewards
+
+    def get_rewards( self, prompt: str, completions: List[str], name: str ) -> torch.FloatTensor:
         # Check if completions are empty, return 0 if so
         if len(completions) == 0:
             return torch.tensor([]).to(self.device)
@@ -95,11 +140,18 @@ def get_rewards( self, prompt: str, completions: List[str], name: str ) -> torch
         # Get embeddings for all completions.
         embeddings = self.get_embeddings( completions )
 
-        # Calculate the pairwise cosine similarity.
-        similarity = pairwise_cosine_similarity( embeddings, embeddings )
+        # Get batch rewards.
+        batch_rewards = self.get_batch_rewards(embeddings)
 
-        # Reward to be at the 10% quantile of the 1 - similarity score.
-        rewards = (1 - similarity).quantile(self.reward_quantile, dim = 1 )
+        # Get historic rewards.
+        historic_rewards = self.get_historic_rewards(embeddings)
+
+        self.update_historic_embeddings(embeddings)
 
         # Return all
-        return rewards
\ No newline at end of file
+        # historic_rewards is None until enough history has accumulated.
+        # Use `is not None`: Tensor overloads `!=`, so `!= None` is fragile.
+        if historic_rewards is not None:
+            return batch_rewards * historic_rewards
+        else:
+            return batch_rewards
\ No newline at end of file