Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions langtest/datahandler/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def load_curated_bias(
original_context=item.get("original_context", "-"),
perturbed_question=item["perturbed_question"],
perturbed_context=item.get("perturbed_context", "-"),
task="question-answering",
test_type=item["test_type"],
category=item["category"],
dataset_name="BoolQ",
Expand All @@ -186,7 +185,6 @@ def load_curated_bias(
SummarizationSample(
original=item["original"],
test_case=item["test_case"],
task="summarization",
test_type=item["test_type"],
category=item["category"],
dataset_name="XSum",
Expand Down Expand Up @@ -805,7 +803,6 @@ def load_data(self) -> List[Sample]:
self.column_matcher["context"], "-"
),
expected_results=expected_results,
task=self.task,
dataset_name=self._file_path.split("/")[-2],
)
)
Expand All @@ -820,15 +817,13 @@ def load_data(self) -> List[Sample]:
SummarizationSample(
original=item[self.column_matcher["text"]],
expected_results=expected_results,
task=self.task,
dataset_name=self._file_path.split("/")[-2],
)
)
elif self.task == "toxicity":
data.append(
ToxicitySample(
prompt=item[self.column_matcher["text"]],
task=self.task,
dataset_name=self._file_path.split("/")[-2],
)
)
Expand All @@ -837,7 +832,6 @@ def load_data(self) -> List[Sample]:
data.append(
TranslationSample(
original=item[self.column_matcher["text"]],
task=self.task,
dataset_name=self._file_path.split("/")[-2],
)
)
Expand Down Expand Up @@ -1075,9 +1069,7 @@ def _row_to_sample_summarization(data_row: Dict[str, str]) -> Sample:
original = data_row.get("document", "")
summary = data_row.get("summary", "")

return SummarizationSample(
original=original, expected_results=summary, task="summarization"
)
return SummarizationSample(original=original, expected_results=summary)

def export_data(self, data: List[Sample], output_path: str):
"""Exports the data to the corresponding format and saves it to 'output_path'.
Expand Down
66 changes: 44 additions & 22 deletions langtest/transform/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ class BaseRepresentation(ABC):
"text-classification",
"question-answering",
"summarization",
"toxicity",
"translation",
]

@staticmethod
@classmethod
@abstractmethod
def transform(
test: str, data: List[Sample], params: Dict
cls, test: str, data: List[Sample], params: Dict
) -> Union[List[MinScoreQASample], List[MinScoreSample]]:
"""Abstract method that implements the representation measure.

Expand All @@ -53,10 +55,10 @@ def transform(
"""
raise NotImplementedError()

@staticmethod
@classmethod
@abstractmethod
async def run(
sample_list: List[Sample], model: ModelFactory, **kwargs
cls, sample_list: List[Sample], model: ModelFactory, **kwargs
) -> List[Sample]:
"""Computes the score for the given data.

Expand Down Expand Up @@ -97,9 +99,9 @@ class GenderRepresentation(BaseRepresentation):
"min_gender_representation_proportion",
]

@staticmethod
@classmethod
def transform(
test: str, data: List[Sample], params: Dict
cls, test: str, data: List[Sample], params: Dict
) -> Union[List[MinScoreQASample], List[MinScoreSample]]:
"""Compute the gender representation measure

Expand All @@ -114,6 +116,10 @@ def transform(
Returns:
Union[List[MinScoreQASample], List[MinScoreSample]]: Gender Representation test results.
"""
assert (
test in cls.alias_name
), f"Parameter 'test' should be in: {cls.alias_name}, got '{test}'"

samples = []
if test == "min_gender_representation_count":
if isinstance(params["min_count"], dict):
Expand Down Expand Up @@ -148,7 +154,7 @@ def transform(
expected_results=MinScoreOutput(min_score=value),
)
samples.append(sample)
elif test == "min_gender_representation_proportion":
else:
min_proportions = {"male": 0.26, "female": 0.26, "unknown": 0.26}

if isinstance(params["min_proportion"], dict):
Expand Down Expand Up @@ -240,7 +246,7 @@ async def run(

elif sample.test_type == "min_gender_representation_count":
sample.actual_results = MinScoreOutput(
min_score=round(gender_counts[sample.test_case], 2)
min_score=gender_counts[sample.test_case]
)
sample.state = "done"
return sample_list
Expand All @@ -259,9 +265,9 @@ class EthnicityRepresentation(BaseRepresentation):
"min_ethnicity_name_representation_proportion",
]

@staticmethod
@classmethod
def transform(
test: str, data: List[Sample], params: Dict
cls, test: str, data: List[Sample], params: Dict
) -> Union[List[MinScoreQASample], List[MinScoreSample]]:
"""Compute the ethnicity representation measure

Expand All @@ -276,8 +282,11 @@ def transform(
Returns:
Union[List[MinScoreQASample], List[MinScoreSample]]: Ethnicity Representation test results.
"""
sample_list = []
assert (
test in cls.alias_name
), f"Parameter 'test' should be in: {cls.alias_name}, got '{test}'"

sample_list = []
if test == "min_ethnicity_name_representation_count":
if not params:
expected_representation = {
Expand Down Expand Up @@ -323,7 +332,7 @@ def transform(
)
sample_list.append(sample)

if test == "min_ethnicity_name_representation_proportion":
else:
if not params:
expected_representation = {
"black": 0.13,
Expand Down Expand Up @@ -447,9 +456,9 @@ class LabelRepresentation(BaseRepresentation):

supported_tasks = ["ner", "text-classification"]

@staticmethod
@classmethod
def transform(
test: str, data: List[Sample], params: Dict
cls, test: str, data: List[Sample], params: Dict
) -> Union[List[MinScoreQASample], List[MinScoreSample]]:
"""Compute the label representation measure

Expand All @@ -464,6 +473,10 @@ def transform(
Returns:
Union[List[MinScoreQASample], List[MinScoreSample]]: Label Representation test results.
"""
assert (
test in cls.alias_name
), f"Parameter 'test' should be in: {cls.alias_name}, got '{test}'"

sample_list = []
labels = [s.expected_results.predictions for s in data]
if isinstance(data[0].expected_results, NEROutput):
Expand Down Expand Up @@ -493,7 +506,7 @@ def transform(
)
sample_list.append(sample)

if test == "min_label_representation_proportion":
else:
if not params:
expected_representation = {k: (1 / len(k)) * 0.8 for k in labels}

Expand Down Expand Up @@ -587,10 +600,16 @@ class ReligionRepresentation(BaseRepresentation):
"min_religion_name_representation_count",
"min_religion_name_representation_proportion",
]
supported_tasks = [
"ner",
"text-classification",
"question-answering",
"summarization",
]

@staticmethod
@classmethod
def transform(
test: str, data: List[Sample], params: Dict
cls, test: str, data: List[Sample], params: Dict
) -> Union[List[MinScoreQASample], List[MinScoreSample]]:
"""Compute the religion representation measure

Expand All @@ -605,8 +624,11 @@ def transform(
Returns:
Union[List[MinScoreQASample], List[MinScoreSample]]: Religion Representation test results.
"""
sample_list = []
assert (
test in cls.alias_name
), f"Parameter 'test' should be in: {cls.alias_name}, got '{test}'"

sample_list = []
if test == "min_religion_name_representation_count":
if not params:
expected_representation = {
Expand Down Expand Up @@ -652,7 +674,7 @@ def transform(
)
sample_list.append(sample)

if test == "min_religion_name_representation_proportion":
else:
if not params:
expected_representation = {
"muslim": 0.11,
Expand Down Expand Up @@ -797,9 +819,9 @@ class CountryEconomicRepresentation(BaseRepresentation):
"min_country_economic_representation_proportion",
]

@staticmethod
@classmethod
def transform(
test: str, data: List[Sample], params: Dict
cls, test: str, data: List[Sample], params: Dict
) -> Union[List[MinScoreQASample], List[MinScoreSample]]:
"""Compute the country economic representation measure

Expand Down Expand Up @@ -858,7 +880,7 @@ def transform(
)
sample_list.append(sample)

if test == "min_country_economic_representation_proportion":
else:
if not params:
expected_representation = {
"high_income": 0.20,
Expand Down
Loading