From 8b54ff46799fbb314f5e348ac1091e553aa32474 Mon Sep 17 00:00:00 2001 From: Eugene Date: Tue, 25 Jul 2023 17:30:30 -0700 Subject: [PATCH 1/5] updates to relevance --- openvalidators/reward/relevance.py | 140 ++++++++++++++++++++++------- 1 file changed, 109 insertions(+), 31 deletions(-) diff --git a/openvalidators/reward/relevance.py b/openvalidators/reward/relevance.py index 56e2063..ca7357c 100644 --- a/openvalidators/reward/relevance.py +++ b/openvalidators/reward/relevance.py @@ -21,13 +21,66 @@ from .config import RewardModelType from .reward import BaseRewardModel from transformers import AutoTokenizer, AutoModel +from torchmetrics.functional import pairwise_cosine_similarity +import torch.nn.functional as F + + +def mean_pooling(model_output, attention_mask): + """Applies mean pooling to the token embeddings generated by the model. + Args: + model_output (torch.Tensor): Embedding model output, where the first element contains token embeddings. + attention_mask (torch.Tensor): Attention mask to indicate valid tokens. + Returns: + torch.Tensor: Mean-pooled representation of the token embeddings. + Notes: + - The function calculates the mean-pooled representation using the attention mask for valid tokens. + - Input_mask_expanded is created by expanding the attention mask to match the size of token embeddings. + - The result is obtained by summing the element-wise multiplication of embeddings and input_mask_expanded, + and dividing it by the sum of input_mask_expanded after clamping its values to a minimum of 1e-9. + """ + token_embeddings = model_output[0] + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( + input_mask_expanded.sum(1), min=1e-9 + ) + +class RelevanceRewardModel( BaseRewardModel ): + + @property + def name(self) -> str: return RewardModelType.relevance.value + + def __init__( self, device: str ): + super().__init__() + self.device = device + self.models = [ + BertRelevanceRewardModel(self.device), + MpnetRelevenceModel(self.device) + ] + self.bounds = [-0.0246, 0.3] + + def get_rewards( self, prompt: str, completions: List[str], name: str ) -> torch.FloatTensor: + return torch.tensor( [self.reward( prompt, completion, name ) for completion in completions], dtype=torch.float32).to(self.device) + + def normalize_rewards( self, rewards: torch.FloatTensor ) -> torch.FloatTensor: + return rewards + + def reward(self, prompt: str, completion: str, name: str) -> float: + for i, model in enumerate(self.models): + + # rewards + diff = model.reward(prompt,completion) + + # If a model returns 0, stop iterating and return 0 + if diff < self.bounds[i]: + return 0.0 + # If none of the models returned 0, return 1 + return 1.0 class BertRelevanceRewardModel( BaseRewardModel ): relevance_model_path = "bert-base-uncased" - - @property - def name(self) -> str: return RewardModelType.relevance.value def __init__( self, device: str ): super().__init__() @@ -35,27 +88,6 @@ def __init__( self, device: str ): self.tokenizer = AutoTokenizer.from_pretrained(BertRelevanceRewardModel.relevance_model_path) self.model = AutoModel.from_pretrained(BertRelevanceRewardModel.relevance_model_path).to(self.device) - def mean_pooling(model_output, attention_mask): - """Applies mean pooling to the token embeddings generated by the model. - Args: - model_output (torch.Tensor): Embedding model output, where the first element contains token embeddings. - attention_mask (torch.Tensor): Attention mask to indicate valid tokens. - Returns: - torch.Tensor: Mean-pooled representation of the token embeddings. - Notes: - - The function calculates the mean-pooled representation using the attention mask for valid tokens. - - Input_mask_expanded is created by expanding the attention mask to match the size of token embeddings. - - The result is obtained by summing the element-wise multiplication of embeddings and input_mask_expanded, - and dividing it by the sum of input_mask_expanded after clamping its values to a minimum of 1e-9. - """ - token_embeddings = model_output[0] - input_mask_expanded = ( - attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() - ) - return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( - input_mask_expanded.sum(1), min=1e-9 - ) - def get_embedding(self, message: str) -> "torch.FloatTensor": """Runs a forward pass through the model. Args: @@ -79,12 +111,12 @@ def get_embedding(self, message: str) -> "torch.FloatTensor": with torch.no_grad(): embeddings = self.model(**encoded_input) - sentence_embeddings = BertRelevanceRewardModel.mean_pooling(embeddings, encoded_input["attention_mask"]) + sentence_embeddings = mean_pooling(embeddings, encoded_input["attention_mask"]) sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1) batch_representation = torch.mean(sentence_embeddings, dim=0) return batch_representation - def reward( self, prompt: str, completion:str , name: str, bound = -0.0246 ) -> float: + def reward( self, prompt: str, completion:str , name: str ) -> float: # Get the two bert embeddings. completion_embedding = self.get_embedding( completion) prompt_embedding = self.get_embedding( prompt) @@ -93,10 +125,56 @@ def reward( self, prompt: str, completion:str , name: str, bound = -0.0246 ) -> diff = (( completion_embedding - prompt_embedding )**2).mean()**0.5 # Return relevance scoring. - return 0.0 if float(-diff) < bound else 1.0 + return float(-diff) - def get_rewards( self, prompt: str, completions: List[str], name: str ) -> torch.FloatTensor: - return torch.tensor( [self.reward( prompt, completion, name ) for completion in completions], dtype=torch.float32).to(self.device) + +class MpnetRelevenceModel( BaseRewardModel ): - def normalize_rewards( self, rewards: torch.FloatTensor ) -> torch.FloatTensor: - return rewards \ No newline at end of file + diversity_model_path = "sentence-transformers/all-mpnet-base-v2" + + def __init__( self, device: str ): + super().__init__() + self.device = device + self.tokenizer = AutoTokenizer.from_pretrained( MpnetRelevenceModel.diversity_model_path ) + self.model = AutoModel.from_pretrained( MpnetRelevenceModel.diversity_model_path ).to(self.device) + self.reward_quantile = torch.tensor(0.1).to(self.device) + + def get_embeddings( self, sentences: List[str] ) -> "torch.FloatTensor": + """Runs a forward pass through the model. + Args: + sentences (:obj:`List[str]`): + text message to be encoded. + Returns: + embedding (:obj:`torch.FloatTensor`): + Embedding for the message. + """ + # Tokenizing sentences + + encoded_input = self.tokenizer( + sentences, + padding=True, + truncation=True, + return_tensors="pt", + ).to(self.device) + + # Compute token embedding + with torch.no_grad(): + embeddings = self.model(**encoded_input) + + # Pooling + sentence_embeddings = mean_pooling(embeddings, encoded_input["attention_mask"]) + + # Normalizing + sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1) + return sentence_embeddings + + def rewards( self, prompt: str, completion: str, name: str ) -> torch.FloatTensor: + + # Get embeddings for all completions. + embeddings = self.get_embeddings( completion ) + prompt_embed = self.get_embeddings( prompt ) + + # Calculate the pairwise cosine similarity. + similarity = pairwise_cosine_similarity( prompt_embed, embeddings ) + + return similarity \ No newline at end of file From a5138eaa72b6faa40dc04c677693f3ce8112e5cc Mon Sep 17 00:00:00 2001 From: Eugene Date: Tue, 25 Jul 2023 17:37:31 -0700 Subject: [PATCH 2/5] imports --- openvalidators/neuron.py | 4 ++-- openvalidators/reward/__init__.py | 2 +- openvalidators/reward/relevance.py | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/openvalidators/neuron.py b/openvalidators/neuron.py index f1a2405..0534a5d 100644 --- a/openvalidators/neuron.py +++ b/openvalidators/neuron.py @@ -37,7 +37,7 @@ NSFWRewardModel, OpenAssistantRewardModel, ReciprocateRewardModel, - BertRelevanceRewardModel, + RelevanceRewardModel, MockRewardModel, DahoasRewardModel, DiversityRewardModel, @@ -187,7 +187,7 @@ def __init__(self): self.masking_functions = [ self.blacklist, - BertRelevanceRewardModel(device=self.device) + RelevanceRewardModel(device=self.device) if not self.config.neuron.relevance_off else MockRewardModel(RewardModelType.relevance.value), DiversityRewardModel(device=self.device) diff --git a/openvalidators/reward/__init__.py b/openvalidators/reward/__init__.py index 6a94469..3277b6e 100644 --- a/openvalidators/reward/__init__.py +++ b/openvalidators/reward/__init__.py @@ -2,7 +2,7 @@ from .nsfw import NSFWRewardModel from .open_assistant import OpenAssistantRewardModel from .reciprocate import ReciprocateRewardModel -from .relevance import BertRelevanceRewardModel +from .relevance import RelevanceRewardModel from .reward import BaseRewardModel from .reward import MockRewardModel from .dahoas import DahoasRewardModel diff --git a/openvalidators/reward/relevance.py b/openvalidators/reward/relevance.py index ca7357c..6fa8592 100644 --- a/openvalidators/reward/relevance.py +++ b/openvalidators/reward/relevance.py @@ -127,7 +127,6 @@ def reward( self, prompt: str, completion:str , name: str ) -> float: # Return relevance scoring. return float(-diff) - class MpnetRelevenceModel( BaseRewardModel ): diversity_model_path = "sentence-transformers/all-mpnet-base-v2" From c7ae4c6e402c6d9286d11f9389c9372253664ae0 Mon Sep 17 00:00:00 2001 From: Eugene Date: Tue, 25 Jul 2023 17:53:35 -0700 Subject: [PATCH 3/5] name --- openvalidators/reward/relevance.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openvalidators/reward/relevance.py b/openvalidators/reward/relevance.py index 6fa8592..f824593 100644 --- a/openvalidators/reward/relevance.py +++ b/openvalidators/reward/relevance.py @@ -116,7 +116,7 @@ def get_embedding(self, message: str) -> "torch.FloatTensor": batch_representation = torch.mean(sentence_embeddings, dim=0) return batch_representation - def reward( self, prompt: str, completion:str , name: str ) -> float: + def reward( self, prompt: str, completion:str ) -> float: # Get the two bert embeddings. completion_embedding = self.get_embedding( completion) prompt_embedding = self.get_embedding( prompt) @@ -167,7 +167,7 @@ def get_embeddings( self, sentences: List[str] ) -> "torch.FloatTensor": sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1) return sentence_embeddings - def rewards( self, prompt: str, completion: str, name: str ) -> torch.FloatTensor: + def rewards( self, prompt: str, completion: str ) -> torch.FloatTensor: # Get embeddings for all completions. embeddings = self.get_embeddings( completion ) From 4226453c368cbe6de8a88a9aeb3c08520140d76a Mon Sep 17 00:00:00 2001 From: Eugene Date: Tue, 25 Jul 2023 17:56:58 -0700 Subject: [PATCH 4/5] fixes --- openvalidators/reward/relevance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openvalidators/reward/relevance.py b/openvalidators/reward/relevance.py index f824593..d41d50f 100644 --- a/openvalidators/reward/relevance.py +++ b/openvalidators/reward/relevance.py @@ -167,7 +167,7 @@ def get_embeddings( self, sentences: List[str] ) -> "torch.FloatTensor": sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1) return sentence_embeddings - def rewards( self, prompt: str, completion: str ) -> torch.FloatTensor: + def reward( self, prompt: str, completion: str ) -> torch.FloatTensor: # Get embeddings for all completions. embeddings = self.get_embeddings( completion ) From c31b14af90feac2aafbfe8913b56e3417014e0ef Mon Sep 17 00:00:00 2001 From: Eugene Date: Wed, 26 Jul 2023 09:48:32 -0700 Subject: [PATCH 5/5] version 1.1.2 --- openvalidators/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openvalidators/__init__.py b/openvalidators/__init__.py index 8ff2f02..703e0e8 100644 --- a/openvalidators/__init__.py +++ b/openvalidators/__init__.py @@ -28,6 +28,6 @@ from . import weights from . import event -__version__ = "1.1.1" +__version__ = "1.1.2" version_split = __version__.split(".") __spec_version__ = (1000 * int(version_split[0])) + (10 * int(version_split[1])) + (1 * int(version_split[2]))