diff --git a/langtest/augmentation/__init__.py b/langtest/augmentation/__init__.py index e84ffdf71..966b0a659 100644 --- a/langtest/augmentation/__init__.py +++ b/langtest/augmentation/__init__.py @@ -162,6 +162,8 @@ def fix( res = TestFactory.transform( self.task, [hash_map[each]], test_type ) + if len(res) == 0: + continue hash_map[each] = res[0] else: if test == "swap_entities": diff --git a/langtest/transform/__init__.py b/langtest/transform/__init__.py index 0fe99a300..fd6420a9c 100644 --- a/langtest/transform/__init__.py +++ b/langtest/transform/__init__.py @@ -1,6 +1,7 @@ import asyncio import copy import time +import logging from collections import defaultdict from abc import ABC, abstractmethod from typing import Dict, List, Union @@ -33,7 +34,7 @@ religion_wise_names, white_names, ) -from .utils import get_substitution_names, create_terminology +from .utils import get_substitution_names, create_terminology, filter_unique_samples from ..modelhandler import ModelFactory from ..utils.custom_types.sample import ( NERSample, @@ -365,7 +366,8 @@ def transform(self) -> List[Sample]: A list of `Sample` objects representing the resulting dataset after running the robustness test. 
""" all_samples = [] - tests_copy = self.tests.copy() # Create a copy of self.tests + no_transformation_applied_tests = set() + tests_copy = self.tests.copy() for test_name, params in tests_copy.items(): if TestFactory.is_augment: data_handler_copy = [x.copy() for x in self._data_handler] @@ -495,11 +497,18 @@ def transform(self) -> List[Sample]: **params.get("parameters", {}), prob=params.pop("prob", 1.0), ) + new_transformed_samples, removed_samples_tests = filter_unique_samples( + TestFactory.task, transformed_samples, test_name + ) + no_transformation_applied_tests.update(removed_samples_tests) + all_samples.extend(new_transformed_samples) + + if no_transformation_applied_tests: + logging.warning( + "Removing samples where no transformation has been applied in the following tests: " + + ", ".join(no_transformation_applied_tests) + ) - for sample in transformed_samples: - if test_name != "multiple_perturbations": - sample.test_type = test_name - all_samples.extend(transformed_samples) return all_samples @staticmethod @@ -675,6 +684,7 @@ def transform(self) -> List[Sample]: A list of `Sample` objects representing the resulting dataset after running the bias test. 
def _strip_spaces(text):
    """Return ``text`` with every space character removed (``None`` becomes ``""``)."""
    return "" if text is None else text.replace(" ", "")


def _text_changed(original, perturbed):
    """Tell whether ``perturbed`` differs from ``original`` once spaces are ignored."""
    if perturbed is None:
        # No perturbed text exists at all, so no transformation was applied.
        # (Calling ``.replace`` on ``None`` would raise ``AttributeError``.)
        return False
    return _strip_spaces(original) != _strip_spaces(perturbed)


def _sample_was_transformed(task, sample):
    """Tell whether the perturbed text of ``sample`` differs from its original text."""
    if task == "question-answering":
        # A change in either the question or the context is enough to keep
        # the sample; the context is only inspected when the question is unchanged.
        return _text_changed(
            sample.original_question, sample.perturbed_question
        ) or _text_changed(sample.original_context, sample.perturbed_context)
    return _text_changed(sample.original, sample.test_case)


def filter_unique_samples(
    task: str, transformed_samples: list, test_name: str
) -> tuple:
    """
    Remove samples on which the transformation had no visible effect.

    A sample is kept when its perturbed text differs from its original text
    after all space characters have been removed from both. For the
    "question-answering" task the question and the context are compared
    (``original_question``/``perturbed_question`` and
    ``original_context``/``perturbed_context``); for every other task
    ``original`` is compared with ``test_case``. A perturbed field that is
    ``None`` counts as "not transformed".

    Kept samples get ``test_type`` set to ``test_name``, except when
    ``test_name`` is "multiple_perturbations", in which case the sample's
    existing ``test_type`` is left untouched.

    Args:
        task (str): The type of task.
        transformed_samples (list): Samples produced by a transformation.
        test_name (str): Name of the test that produced the samples.

    Returns:
        tuple: ``(new_transformed_samples, no_transformation_applied_tests)``
        where the first item is the list of kept samples (input order
        preserved) and the second is the set of test names for which at
        least one sample was dropped. For "multiple_perturbations" the
        dropped sample's own ``test_type`` is recorded instead of
        ``test_name``.
    """
    no_transformation_applied_tests = set()
    new_transformed_samples = []
    is_multiple = test_name == "multiple_perturbations"

    for sample in transformed_samples:
        if _sample_was_transformed(task, sample):
            if not is_multiple:
                sample.test_type = test_name
            new_transformed_samples.append(sample)
        else:
            no_transformation_applied_tests.add(
                sample.test_type if is_multiple else test_name
            )

    return new_transformed_samples, no_transformation_applied_tests
def test_filtering_Out_Same_Original_And_TestCase(self):
    """
    Test filtering out records where 'original' and 'test_case' are the same for text classification task.
    """
    tc_harness = Harness(
        task="text-classification",
        model={"model": "bert-base-cased", "hub": "huggingface"},
        data={"data_source": "tests/fixtures/text_classification.csv"},
        config="tests/fixtures/config_text_classification.yaml",
    )
    tc_harness.generate()
    df = tc_harness.testcases()
    # Rows whose perturbed text equals the original text exactly; the test
    # expects none of them to be present in the generated test cases.
    filtered_df = df[df["original"] == df["test_case"]]
    self.assertTrue(
        filtered_df.empty,
        "generated test cases still contain rows where original == test_case",
    )