From 32fc34e7f9c18dd0fc4646efd9577e2927dac7bb Mon Sep 17 00:00:00 2001 From: p-ferreira <38992619+p-ferreira@users.noreply.github.com> Date: Fri, 4 Aug 2023 10:39:44 -0400 Subject: [PATCH] complements task validator keywords --- openvalidators/reward/config.py | 4 +--- openvalidators/reward/task_validator.py | 6 +++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/openvalidators/reward/config.py b/openvalidators/reward/config.py index 86354cb..d3b8376 100644 --- a/openvalidators/reward/config.py +++ b/openvalidators/reward/config.py @@ -1,7 +1,5 @@ # The MIT License (MIT) # Copyright © 2021 Yuma Rao -from dataclasses import dataclass - # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the “Software”), to deal in the Software without restriction, including without limitation # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, @@ -15,7 +13,7 @@ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. - +from dataclasses import dataclass from enum import Enum diff --git a/openvalidators/reward/task_validator.py b/openvalidators/reward/task_validator.py index 6b7a1af..e9dfb77 100644 --- a/openvalidators/reward/task_validator.py +++ b/openvalidators/reward/task_validator.py @@ -29,9 +29,9 @@ def __init__(self): super().__init__() def reward( self, prompt: str, completion: str, name: str ) -> float: - summary_keywords = ['Summary:'] - question_keywords = ['Question:'] - answer_keywords = ['Answer:'] + summary_keywords = ['Summary:', 'Paraphrase:', 'Paraphrasing:', 'Paraphrased:'] + question_keywords = ['Question:', 'Query:', 'Q:'] + answer_keywords = ['Answer:', 'Response:', 'A:', 'Completion:'] completion_contains_answer = any(answer_keyword.lower() in completion.lower() for answer_keyword in answer_keywords) completion_contains_question = any(question_keyword.lower() in completion.lower() for question_keyword in question_keywords)