Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/pages/docs/data.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Supported `data_source` formats are task-dependent. The following table provides
| ----------------------- | -------------------------------------------------------- |
| **ner** | CoNLL, CSV and HuggingFace Datasets |
| **text-classification** | CSV and HuggingFace Datsets |
| **question-answering** | Select list of benchmark datasets or HuggingFace Datsets |
| **question-answering** | Select list of benchmark datasets |
| **summarization** | Select list of benchmark datasets or HuggingFace Datsets |
| **toxicity** | Select list of benchmark datasets |
| **clinical-tests** | Select list of curated datasets |
Expand Down
28 changes: 18 additions & 10 deletions langtest/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def transform(self) -> List[Sample]:
A list of `Sample` objects representing the resulting dataset after running the robustness test.
"""
all_samples = []
no_transformation_applied_tests = set()
no_transformation_applied_tests = {}
tests_copy = self.tests.copy()
for test_name, params in tests_copy.items():
if TestFactory.is_augment:
Expand Down Expand Up @@ -505,14 +505,18 @@ def transform(self) -> List[Sample]:
new_transformed_samples, removed_samples_tests = filter_unique_samples(
TestFactory.task, transformed_samples, test_name
)
no_transformation_applied_tests.update(removed_samples_tests)
all_samples.extend(new_transformed_samples)

no_transformation_applied_tests.update(removed_samples_tests)

if no_transformation_applied_tests:
logging.warning(
"Removing samples where no transformation has been applied in the following tests: "
+ ", ".join(no_transformation_applied_tests)
warning_message = (
"Removing samples where no transformation has been applied:\n"
)
for test, count in no_transformation_applied_tests.items():
warning_message += f"- Test '{test}': {count} samples removed out of {len(self._data_handler)}\n"

logging.warning(warning_message)

return all_samples

Expand Down Expand Up @@ -689,7 +693,7 @@ def transform(self) -> List[Sample]:
A list of `Sample` objects representing the resulting dataset after running the bias test.
"""
all_samples = []
no_transformation_applied_tests = set()
no_transformation_applied_tests = {}
for test_name, params in self.tests.items():
data_handler_copy = [x.copy() for x in self._data_handler]

Expand All @@ -700,14 +704,18 @@ def transform(self) -> List[Sample]:
new_transformed_samples, removed_samples_tests = filter_unique_samples(
TestFactory.task, transformed_samples, test_name
)
no_transformation_applied_tests.update(removed_samples_tests)
all_samples.extend(new_transformed_samples)

no_transformation_applied_tests.update(removed_samples_tests)

if no_transformation_applied_tests:
logging.warning(
"Removing samples where no transformation has been applied in the following tests: "
+ ", ".join(no_transformation_applied_tests)
warning_message = (
"Removing samples where no transformation has been applied:\n"
)
for test, count in no_transformation_applied_tests.items():
warning_message += f"- Test '{test}': {count} samples removed out of {len(self._data_handler)}\n"

logging.warning(warning_message)

return all_samples

Expand Down
26 changes: 19 additions & 7 deletions langtest/transform/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,10 +364,10 @@ def filter_unique_samples(task: str, transformed_samples: list, test_name: str):

Returns:
new_transformed_samples (list): List of filtered samples with unique transformations.
no_transformation_applied_tests (set): Set of test names for which no transformation
was applied due to non-uniqueness.
no_transformation_applied_tests (dict): A dictionary where keys are test names and
values are the number of samples removed from each test.
"""
no_transformation_applied_tests = set()
no_transformation_applied_tests = {}
new_transformed_samples = []
if task == "question-answering":
for sample in transformed_samples:
Expand All @@ -383,9 +383,15 @@ def filter_unique_samples(task: str, transformed_samples: list, test_name: str):
new_transformed_samples.append(sample)
else:
if test_name == "multiple_perturbations":
no_transformation_applied_tests.add(sample.test_type)
if sample.test_type in no_transformation_applied_tests:
no_transformation_applied_tests[sample.test_type] += 1
else:
no_transformation_applied_tests[sample.test_type] = 1
else:
no_transformation_applied_tests.add(test_name)
if test_name in no_transformation_applied_tests:
no_transformation_applied_tests[test_name] += 1
else:
no_transformation_applied_tests[test_name] = 1
else:
for sample in transformed_samples:
if sample.original.replace(" ", "") != sample.test_case.replace(" ", ""):
Expand All @@ -394,8 +400,14 @@ def filter_unique_samples(task: str, transformed_samples: list, test_name: str):
new_transformed_samples.append(sample)
else:
if test_name == "multiple_perturbations":
no_transformation_applied_tests.add(sample.test_type)
if sample.test_type in no_transformation_applied_tests:
no_transformation_applied_tests[sample.test_type] += 1
else:
no_transformation_applied_tests[sample.test_type] = 1
else:
no_transformation_applied_tests.add(test_name)
if test_name in no_transformation_applied_tests:
no_transformation_applied_tests[test_name] += 1
else:
no_transformation_applied_tests[test_name] = 1

return new_transformed_samples, no_transformation_applied_tests