Skip to content
2 changes: 2 additions & 0 deletions langtest/augmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def fix(
res = TestFactory.transform(
self.task, [hash_map[each]], test_type
)
if len(res) == 0:
continue
hash_map[each] = res[0]
else:
if test == "swap_entities":
Expand Down
36 changes: 27 additions & 9 deletions langtest/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import copy
import time
import logging
from collections import defaultdict
from abc import ABC, abstractmethod
from typing import Dict, List, Union
Expand Down Expand Up @@ -33,7 +34,7 @@
religion_wise_names,
white_names,
)
from .utils import get_substitution_names, create_terminology
from .utils import get_substitution_names, create_terminology, filter_unique_samples
from ..modelhandler import ModelFactory
from ..utils.custom_types.sample import (
NERSample,
Expand Down Expand Up @@ -365,7 +366,8 @@ def transform(self) -> List[Sample]:
A list of `Sample` objects representing the resulting dataset after running the robustness test.
"""
all_samples = []
tests_copy = self.tests.copy() # Create a copy of self.tests
no_transformation_applied_tests = set()
tests_copy = self.tests.copy()
for test_name, params in tests_copy.items():
if TestFactory.is_augment:
data_handler_copy = [x.copy() for x in self._data_handler]
Expand Down Expand Up @@ -495,11 +497,18 @@ def transform(self) -> List[Sample]:
**params.get("parameters", {}),
prob=params.pop("prob", 1.0),
)
new_transformed_samples, removed_samples_tests = filter_unique_samples(
TestFactory.task, transformed_samples, test_name
)
no_transformation_applied_tests.update(removed_samples_tests)
all_samples.extend(new_transformed_samples)

if no_transformation_applied_tests:
logging.warning(
"Removing samples where no transformation has been applied in the following tests: "
+ ", ".join(no_transformation_applied_tests)
)

for sample in transformed_samples:
if test_name != "multiple_perturbations":
sample.test_type = test_name
all_samples.extend(transformed_samples)
return all_samples

@staticmethod
Expand Down Expand Up @@ -675,16 +684,25 @@ def transform(self) -> List[Sample]:
A list of `Sample` objects representing the resulting dataset after running the bias test.
"""
all_samples = []
no_transformation_applied_tests = set()
for test_name, params in self.tests.items():
data_handler_copy = [x.copy() for x in self._data_handler]

transformed_samples = self.supported_tests[test_name].transform(
data_handler_copy, **params.get("parameters", {})
)

for sample in transformed_samples:
sample.test_type = test_name
all_samples.extend(transformed_samples)
new_transformed_samples, removed_samples_tests = filter_unique_samples(
TestFactory.task, transformed_samples, test_name
)
no_transformation_applied_tests.update(removed_samples_tests)
all_samples.extend(new_transformed_samples)

if no_transformation_applied_tests:
logging.warning(
"Removing samples where no transformation has been applied in the following tests: "
+ ", ".join(no_transformation_applied_tests)
)

return all_samples

Expand Down
48 changes: 48 additions & 0 deletions langtest/transform/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,51 @@ def check_name(word: str, name_lists: List[List[str]]) -> bool:
return any(
word.lower() in [name.lower() for name in name_list] for name_list in name_lists
)


def filter_unique_samples(task: str, transformed_samples: list, test_name: str):
    """
    Filter and remove samples with no applied transformations from the list of transformed_samples.

    A sample counts as "transformed" when its perturbed text differs from the
    original after stripping all spaces (so whitespace-only changes do not
    count). Kept samples have their ``test_type`` stamped with *test_name*
    unless the test is ``"multiple_perturbations"``, whose samples already
    carry their own per-perturbation ``test_type``.

    Args:
        task (str): The type of task (e.g. "question-answering").
        transformed_samples (list): List of transformed samples to be filtered.
        test_name (str): Name of the test.

    Returns:
        new_transformed_samples (list): List of filtered samples with unique transformations.
        no_transformation_applied_tests (set): Set of test names for which no transformation
            was applied due to non-uniqueness.
    """
    # Select the per-task change detector once, instead of duplicating the
    # keep/drop bookkeeping in two near-identical loops.
    if task == "question-answering":

        def _is_transformed(sample) -> bool:
            # QA samples change if either the question or the context changed.
            return (
                sample.original_question.replace(" ", "")
                != sample.perturbed_question.replace(" ", "")
            ) or (
                sample.original_context.replace(" ", "")
                != sample.perturbed_context.replace(" ", "")
            )

    else:

        def _is_transformed(sample) -> bool:
            return sample.original.replace(" ", "") != sample.test_case.replace(" ", "")

    no_transformation_applied_tests = set()
    new_transformed_samples = []
    for sample in transformed_samples:
        if _is_transformed(sample):
            if test_name != "multiple_perturbations":
                sample.test_type = test_name
            new_transformed_samples.append(sample)
        elif test_name == "multiple_perturbations":
            # Each sample already records which perturbation produced it.
            no_transformation_applied_tests.add(sample.test_type)
        else:
            no_transformation_applied_tests.add(test_name)

    return new_transformed_samples, no_transformation_applied_tests
2 changes: 1 addition & 1 deletion tests/test_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_csv_dataset_textclassification_hf(self):
harness.data = harness.data[:50]
report = harness.generate().run().report()
self.assertIsInstance(report, pd.DataFrame)
custom_proportions = {"uppercase": 0.8, "lowercase": 0.8}
custom_proportions = {"uppercase": 0.8}
harness.augment(
training_data={"data_source": "tests/fixtures/text_classification.csv"},
save_data_path="tests/fixtures/augmented_text_classification.csv",
Expand Down
16 changes: 16 additions & 0 deletions tests/test_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,22 @@ def test_ner_csv_custom_columns(self):
self.assertEqual(tc_harness.data, loaded_tc_harness.data)
self.assertNotEqual(tc_harness.model, loaded_tc_harness.model)

def test_filtering_Out_Same_Original_And_TestCase(self):
    """
    Test filtering out records where 'original' and 'test_case' are the same
    for the text classification task.

    After `generate()`, no test case should be identical to its original
    sentence, because unchanged samples are dropped by the transform filter.
    """
    # NOTE: the previous version assigned an unused `save_dir`; removed.
    tc_harness = Harness(
        task="text-classification",
        model={"model": "bert-base-cased", "hub": "huggingface"},
        data={"data_source": "tests/fixtures/text_classification.csv"},
        config="tests/fixtures/config_text_classification.yaml",
    )
    tc_harness.generate()
    df = tc_harness.testcases()
    # Rows where the perturbation left the text unchanged must be absent.
    filtered_df = df[df["original"] == df["test_case"]]
    self.assertTrue(filtered_df.empty)


class DefaultCodeBlocksTestCase(unittest.TestCase):
"""
Expand Down