diff --git a/langtest/augmentation/__init__.py b/langtest/augmentation/__init__.py index e84ffdf71..2d77cc3c6 100644 --- a/langtest/augmentation/__init__.py +++ b/langtest/augmentation/__init__.py @@ -437,6 +437,10 @@ def new_sample(self, template: Sample): ) for result in template.expected_results.predictions[cursor:]: if prediction[0].entity.endswith(result.entity): + for each_prediction in prediction: + if isinstance(each_prediction, NERPrediction): + each_prediction.chunk_tag = "-X-" + each_prediction.pos_tag = "-X-" other_predictions.extend(prediction) cursor += 1 break diff --git a/langtest/datahandler/datasource.py b/langtest/datahandler/datasource.py index 883e13fe3..d36382911 100644 --- a/langtest/datahandler/datasource.py +++ b/langtest/datahandler/datasource.py @@ -407,8 +407,9 @@ def export_data(self, data: List[NERSample], output_path: str): path to save the data to """ otext = "" + temp_id = None for i in data: - text = Formatter.process(i, output_format="conll") + text, temp_id = Formatter.process(i, output_format="conll", temp_id=temp_id) otext += text + "\n" with open(output_path, "wb") as fwriter: diff --git a/langtest/datahandler/format.py b/langtest/datahandler/format.py index 3552c5cf1..72cacbaf7 100644 --- a/langtest/datahandler/format.py +++ b/langtest/datahandler/format.py @@ -158,9 +158,7 @@ def to_csv( return tokens, labels, [], [] @staticmethod - def to_conll( - sample: NERSample, writing_mode: str = "ignore" - ) -> Union[str, Tuple[str, str]]: + def to_conll(sample: NERSample, temp_id: int = None) -> Union[str, Tuple[str, str]]: """Converts a custom type to a CoNLL string. Args: @@ -176,36 +174,47 @@ def to_conll( Returns: The CoNLL string representation of the custom type. """ - assert writing_mode in [ - "ignore", - "append", - "separate", - ], f"writing_mode: {writing_mode} not supported." - text, text_perturbed = "", "" + text = "" test_case = sample.test_case original = sample.original - - words = re.finditer(r"([^\s]+)", original) - - for word in words: - token = word.group() - match = sample.expected_results[word.group()] - label = match.entity if match is not None else "O" - text += f"{token} -X- -X- {label}\n" - - if test_case and writing_mode != "ignore": - words = re.finditer(r"([^\s]+)", test_case) - - for word in words: - token = word.group() - match = sample.actual_results[word.group()] - label = match.entity if match is not None else "O" - if writing_mode == "append": - text += f"{token} -X- -X- {label}\n" - elif writing_mode == "separate": - text_perturbed += f"{token} -X- -X- {label}\n" - - if writing_mode == "separate": - return text, text_perturbed - return text + if test_case: + test_case_items = test_case.split() + norm_test_case_items = test_case.lower().split() + norm_original_items = original.lower().split() + temp_len = 0 + for jdx, item in enumerate(norm_test_case_items): + try: + if item in norm_original_items and jdx >= norm_original_items.index( + item + ): + oitem_index = norm_original_items.index(item) + j = sample.expected_results.predictions[oitem_index + temp_len] + if temp_id != j.doc_id and jdx == 0: + text += f"{j.doc_name}\n\n" + temp_id = j.doc_id + text += f"{test_case_items[jdx]} {j.pos_tag} {j.chunk_tag} {j.entity}\n" + norm_original_items.pop(oitem_index) + temp_len += 1 + else: + o_item = sample.expected_results.predictions[jdx].span.word + letters_count = len(set(item) - set(o_item)) + if ( + len(norm_test_case_items) == len(original.lower().split()) + or letters_count < 2 + ): + tl = sample.expected_results.predictions[jdx] + text += f"{test_case_items[jdx]} {tl.pos_tag} {tl.chunk_tag} {tl.entity}\n" + else: + text += f"{test_case_items[jdx]} -X- -X- O\n" + except IndexError: + text += f"{test_case_items[jdx]} -X- -X- O\n" + + else: + for j in sample.expected_results.predictions: + if temp_id != j.doc_id: + text += f"{j.doc_name}\n\n" + temp_id = j.doc_id + text += f"{j.span.word} {j.pos_tag} {j.chunk_tag} {j.entity}\n" + + return text, temp_id diff --git a/langtest/transform/robustness.py b/langtest/transform/robustness.py index 3cfb15a7d..d3e7a526b 100644 --- a/langtest/transform/robustness.py +++ b/langtest/transform/robustness.py @@ -487,13 +487,17 @@ def transform( sample.category = "robustness" if all([label == "O" for label in sample_labels]): sample.test_case = sample.original - continue + break sent_tokens = sample.original.split(" ") ent_start_pos = [1 if label[0] == "B" else 0 for label in sample_labels] ent_idx = [i for i, value in enumerate(ent_start_pos) if value == 1] + if not ent_idx: + sample.test_case = sample.original + break + replace_idx = random.choice(ent_idx) ent_type = sample_labels[replace_idx][2:] replace_idxs = [replace_idx] diff --git a/tests/test_datasource.py b/tests/test_datasource.py index 95b297ead..4a923472c 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -31,7 +31,10 @@ class TestNERDataset: test_type="add_context", expected_results=NEROutput( predictions=[ - NERPrediction(entity="PROD", span=Span(start=10, end=13, word="KFC")) + NERPrediction(entity="O", span=Span(start=10, end=13, word="I")), + NERPrediction(entity="O", span=Span(start=10, end=13, word="do")), + NERPrediction(entity="O", span=Span(start=10, end=13, word="love")), + NERPrediction(entity="PROD", span=Span(start=10, end=13, word="KFC")), ] ), )