Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions openvalidators/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,12 @@ def add_args(cls, parser):
action="store_true",
help="Dont apply the diversity reward model",
default=False,
)
parser.add_argument(
"--neuron.task_validator_off",
action="store_true",
help="Dont apply the task validator reward model",
default=False,
)

parser.add_argument(
Expand Down
2 changes: 2 additions & 0 deletions openvalidators/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class EventSchema:
rlhf_reward_model: Optional[List[float]] # Output vector of the rlhf reward model
prompt_reward_model: Optional[List[float]] # Output vector of the prompt reward model
relevance_filter: Optional[List[float]] # Output vector of the relevance scoring reward model
task_validator_filter: Optional[List[float]]

# Weights data
set_weights: Optional[List[List[float]]]
Expand All @@ -54,6 +55,7 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> 'EventSchema':
rewards = {
'dahoas_reward_model': event_dict.get(RewardModelType.dahoas.value),
'blacklist_filter': event_dict.get(RewardModelType.blacklist.value),
'task_validator_filter': event_dict.get(RewardModelType.task_validator.value),
'nsfw_filter': event_dict.get(RewardModelType.nsfw.value),
'relevance_filter': event_dict.get(RewardModelType.relevance.value),
'reciprocate_reward_model': event_dict.get(RewardModelType.reciprocate.value),
Expand Down
34 changes: 21 additions & 13 deletions openvalidators/neuron.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
# Load gating models
from openvalidators.reward import (
Blacklist,
TaskValidator,
NSFWRewardModel,
OpenAssistantRewardModel,
ReciprocateRewardModel,
Expand Down Expand Up @@ -185,24 +186,31 @@ def __init__(self):

bt.logging.error(message)
raise Exception(message)


# Masking functions
self.blacklist = (
Blacklist() if not self.config.neuron.blacklist_off else MockRewardModel(RewardModelType.blacklist.value)
)
task_validator = (
TaskValidator() if not self.config.neuron.task_validator_off
else MockRewardModel(RewardModelType.task_validator.value)
)
relevance_model = (
RelevanceRewardModel(device=self.device) if not self.config.neuron.relevance_off
else MockRewardModel(RewardModelType.relevance.value)
)
diversity_model = (
DiversityRewardModel(device=self.device) if not self.config.neuron.diversity_off
else MockRewardModel(RewardModelType.diversity.value)
)
nsfw_model = (
NSFWRewardModel(device=self.device) if not self.config.neuron.nsfw_off
else MockRewardModel(RewardModelType.nsfw.value)
)

self.masking_functions = [
self.blacklist,
RelevanceRewardModel(device=self.device)
if not self.config.neuron.relevance_off
else MockRewardModel(RewardModelType.relevance.value),
DiversityRewardModel(device=self.device)
if not self.config.neuron.diversity_off
else MockRewardModel(RewardModelType.diversity.value),
NSFWRewardModel(device=self.device)
if not self.config.neuron.nsfw_off
else MockRewardModel(RewardModelType.nsfw.value),
]
self.masking_functions = [self.blacklist, task_validator, relevance_model, diversity_model, nsfw_model]
bt.logging.debug(str(self.reward_functions))
bt.logging.debug(str(self.masking_functions))

# Init the event loop.
self.loop = asyncio.get_event_loop()
Expand Down
8 changes: 4 additions & 4 deletions openvalidators/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,16 +348,16 @@ def find_unique_tags(input_text: str):

def followup_prompt( base_text:str, i:int = 0) -> str:
    """Build a prompt asking the model to generate a follow-up question about `base_text`.

    Args:
        base_text: The context the question should be about.
        i: Follow-up round index; rounds after the first also reference previous questions.

    Returns:
        The fully formatted question-generation prompt. The trailing instruction
        discourages the model from answering or summarizing instead of asking.
    """
    if i == 0:
        return f"{base_text}\n\n{followup_request_template}\n. Do not try to return an answer or a summary:"
    else:
        return f"{base_text}\n\n{followup_request_template} and previous questions. Do not try to return an answer or a summary:\n"


def answer_prompt( base_text:str, followup:str ) -> str:
    """Build a prompt asking the model to answer `followup` given `base_text`.

    The instruction forbids embedding questions or summaries in the answer so the
    task-validator reward model does not zero out the completion.
    """
    return f"{base_text}\n\nQuestion:{followup}\nAnswer the question step by step and explain your thoughts. Do not include questions or summaries in your answer."

augment_request_template = "Summarize the preceding context"

def augment_prompt( base_text:str ) -> str:
    """Build a summarization prompt for `base_text`.

    The requested summary length is drawn uniformly from 4-8 sentences to add
    variety; questions/answers are explicitly disallowed in the output.
    """
    random_level = random.randint(4, 8)
    return f"{base_text}\n\n{augment_request_template} in {random_level} sentences. Do not try to create questions or answers for your summarization.\n\n"
3 changes: 2 additions & 1 deletion openvalidators/reward/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .blacklist import Blacklist
from .task_validator import TaskValidator
from .nsfw import NSFWRewardModel
from .open_assistant import OpenAssistantRewardModel
from .reciprocate import ReciprocateRewardModel
Expand All @@ -8,4 +9,4 @@
from .dahoas import DahoasRewardModel
from .diversity import DiversityRewardModel
from .prompt import PromptRewardModel
from .config import RewardModelType, DefaultRewardFrameworkConfig
from .config import RewardModelType, DefaultRewardFrameworkConfig
1 change: 1 addition & 0 deletions openvalidators/reward/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class RewardModelType(Enum):
blacklist = 'blacklist_filter'
nsfw = 'nsfw_filter'
relevance = 'relevance_filter'
task_validator = 'task_validator_filter'


@dataclass(frozen=True)
Expand Down
63 changes: 63 additions & 0 deletions openvalidators/reward/task_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# The MIT License (MIT)
# Copyright © 2021 Yuma Rao

# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software.

# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import torch
from typing import List
from .config import RewardModelType
from .reward import BaseRewardModel


class TaskValidator( BaseRewardModel ):
    """Masking reward model that zeroes out completions leaking content from
    the wrong task type.

    Rules enforced by :meth:`reward`:
      * summarization ('augment') completions may not contain answers or questions,
      * question-generation ('followup*') completions may not contain answers,
      * answer ('answer*') completions may not contain questions,
      * only summarization completions may contain a summary marker.
    """

    @property
    def name(self) -> str:
        # Registry key used when logging this model's output vector.
        return RewardModelType.task_validator.value

    def __init__(self):
        super().__init__()

    def reward( self, prompt: str, completion: str, name: str ) -> float:
        """Return 1.0 if `completion` is valid for task `name`, else 0.0.

        Args:
            prompt: The prompt sent to the miner (unused; kept for interface parity).
            completion: The miner's completion text.
            name: Task identifier, e.g. 'augment', 'followup0', 'answer1'.
        """
        # Case-insensitive detection of task-marker keywords in the completion.
        completion_lower = completion.lower()
        completion_contains_answer = 'answer:' in completion_lower
        completion_contains_question = 'question:' in completion_lower
        completion_contains_summary = 'summary:' in completion_lower

        is_summarization_prompt = name == 'augment'
        is_question_prompt = name.startswith('followup')
        is_answer_prompt = name.startswith('answer')

        # Summaries and questions must not include an answer.
        if (is_summarization_prompt or is_question_prompt) and completion_contains_answer:
            return 0.0

        # Summaries and answers must not include a question.
        if (is_summarization_prompt or is_answer_prompt) and completion_contains_question:
            return 0.0

        # Only summarization tasks may include a summary marker.
        if not is_summarization_prompt and completion_contains_summary:
            return 0.0

        # Was `return 1` (int); return a float to match the declared return type
        # and the 0.0 failure paths.
        return 1.0

    def get_rewards( self, prompt: str, completions: List[str], name: str ) -> torch.FloatTensor:
        """Score every completion of `prompt` for task `name` as a float32 tensor."""
        return torch.tensor( [self.reward( prompt, completion, name ) for completion in completions], dtype=torch.float32)

    def normalize_rewards( self, rewards: torch.FloatTensor ) -> torch.FloatTensor:
        # Rewards are already a binary {0.0, 1.0} mask; no normalization needed.
        return rewards

    def reset(self):
        # Stateless model: nothing to reset between epochs.
        pass

Empty file added tests/reward/__init__.py
Empty file.
101 changes: 101 additions & 0 deletions tests/reward/test_task_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import unittest
from openvalidators.reward.task_validator import TaskValidator

class TaskValidatorTestCase(unittest.TestCase):
    """Unit tests for the TaskValidator reward model.

    Each test exercises TaskValidator.reward, which returns 1.0 when a
    completion matches its task type and 0.0 when it leaks content belonging
    to another task (questions, answers or summaries).
    """

    def setUp(self):
        self.validator = TaskValidator()

    def test_augment_with_answer_keyword(self):
        """
        Reward is 0 when the task name is 'augment' (summarization)
        and the completion contains the 'Answer:' keyword.
        """
        name = 'augment'  # was f'augment': no placeholders, plain literal suffices
        completion = "Summary: test summary\nAnswer: Test answer"
        self.assertEqual(self.validator.reward('', completion, name), 0.0)

    def test_followup_with_answer_keyword(self):
        """
        Reward is 0 when the task name starts with 'followup' (question generation)
        and the completion contains the 'Answer:' keyword.
        """
        for i in range(0, 4):
            name = f'followup{i}'
            completion = 'Question: This is a test question?\nAnswer: This is a test answer.'
            self.assertEqual(self.validator.reward('', completion, name), 0.0)

    def test_augment_with_question_keyword(self):
        """
        Reward is 0 when the task name is 'augment' (summarization)
        and the completion contains the 'Question:' keyword.
        """
        name = 'augment'
        completion = "Summary: test summary\nQuestion: This is a test question?"
        self.assertEqual(self.validator.reward('', completion, name), 0.0)

    def test_answer_with_question_keyword(self):
        """
        Reward is 0 when the task name starts with 'answer' (answer generation)
        and the completion contains the 'Question:' keyword.
        """
        for i in range(0, 4):
            name = f'answer{i}'
            completion = 'Question: This is a test question?\nAnswer: This is a test answer.'
            self.assertEqual(self.validator.reward('', completion, name), 0.0)

    def test_followup_and_answer_with_summary_keyword(self):
        """
        Reward is 0 when the task name is not 'augment' (summarization)
        and the completion contains the 'Summary:' keyword.
        """
        for name in ['followup0', 'followup1', 'followup2', 'followup3', 'answer0', 'answer1', 'answer2', 'answer3']:
            completion = 'Summary: This is a test summary.'
            self.assertEqual(self.validator.reward('', completion, name), 0.0)

    def test_reward_valid_followup(self):
        """
        Reward is 1 when the task name starts with 'followup' (question generation)
        and the completion contains a question.
        """
        for i in range(0, 4):
            name = f'followup{i}'
            completion = 'Question: This is a test question?'
            self.assertEqual(self.validator.reward('', completion, name), 1.0)

    def test_reward_valid_answer(self):
        """
        Reward is 1 when the task name starts with 'answer' (answer generation)
        and the completion contains an answer.
        """
        for i in range(0, 4):
            name = f'answer{i}'
            completion = 'Answer: This is a test answer.'
            self.assertEqual(self.validator.reward('', completion, name), 1.0)

    def test_reward_valid_augment(self):
        """
        Reward is 1 when the task name is 'augment' (summarization)
        and the completion contains a summary.
        """
        name = 'augment'
        completion = 'Summary: This is a test summary.'
        self.assertEqual(self.validator.reward('', completion, name), 1.0)

    def test_reward_valid_other(self):
        """
        Reward is 1 for question/answer tasks whose completion contains none of
        the 'Summary:', 'Answer:' or 'Question:' keywords.

        (Docstring fixed: the original claimed the names differ from 'followup'
        and 'answer', but the loop iterates exactly those task names.)
        """
        for name in ['followup0', 'followup1', 'followup2', 'followup3', 'answer0', 'answer1', 'answer2', 'answer3']:
            completion = 'This is a test completion.'
            self.assertEqual(self.validator.reward('', completion, name), 1.0)

if __name__ == '__main__':
    unittest.main()
3 changes: 3 additions & 0 deletions tests/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_event_from_dict_all_forward_columns_match(self):
RewardModelType.rlhf.value: [1.0],
RewardModelType.prompt.value: [1.0],
RewardModelType.relevance.value: [1.0],
RewardModelType.task_validator.value: [1.0]
}

# Act
Expand Down Expand Up @@ -102,6 +103,7 @@ def test_event_from_dict_forward_no_reward_logging(self):
assert event.rlhf_reward_model is None
assert event.prompt_reward_model is None
assert event.relevance_filter is None
assert event.task_validator_filter is None

def test_event_from_dict_forward_reward_logging_mismatch(self):
"""Test that all default columns logged on the forward pass are correctly converted and that
Expand Down Expand Up @@ -142,4 +144,5 @@ def test_event_from_dict_forward_reward_logging_mismatch(self):
assert event.rlhf_reward_model is None
assert event.prompt_reward_model is None
assert event.relevance_filter is None
assert event.task_validator_filter is None