Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions langtest/augmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,10 @@ def new_sample(self, template: Sample):
)
for result in template.expected_results.predictions[cursor:]:
if prediction[0].entity.endswith(result.entity):
for each_prediction in prediction:
if isinstance(each_prediction, NERPrediction):
each_prediction.chunk_tag = "-X-"
each_prediction.pos_tag = "-X-"
other_predictions.extend(prediction)
cursor += 1
break
Expand Down
3 changes: 2 additions & 1 deletion langtest/datahandler/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,9 @@ def export_data(self, data: List[NERSample], output_path: str):
path to save the data to
"""
otext = ""
temp_id = None
for i in data:
text = Formatter.process(i, output_format="conll")
text, temp_id = Formatter.process(i, output_format="conll", temp_id=temp_id)
otext += text + "\n"

with open(output_path, "wb") as fwriter:
Expand Down
75 changes: 42 additions & 33 deletions langtest/datahandler/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,7 @@ def to_csv(
return tokens, labels, [], []

@staticmethod
def to_conll(
sample: NERSample, writing_mode: str = "ignore"
) -> Union[str, Tuple[str, str]]:
def to_conll(sample: NERSample, temp_id: int = None) -> Union[str, Tuple[str, str]]:
"""Converts a custom type to a CoNLL string.

Args:
Expand All @@ -176,36 +174,47 @@ def to_conll(
Returns:
The CoNLL string representation of the custom type.
"""
assert writing_mode in [
"ignore",
"append",
"separate",
], f"writing_mode: {writing_mode} not supported."

text, text_perturbed = "", ""
text = ""
test_case = sample.test_case
original = sample.original

words = re.finditer(r"([^\s]+)", original)

for word in words:
token = word.group()
match = sample.expected_results[word.group()]
label = match.entity if match is not None else "O"
text += f"{token} -X- -X- {label}\n"

if test_case and writing_mode != "ignore":
words = re.finditer(r"([^\s]+)", test_case)

for word in words:
token = word.group()
match = sample.actual_results[word.group()]
label = match.entity if match is not None else "O"
if writing_mode == "append":
text += f"{token} -X- -X- {label}\n"
elif writing_mode == "separate":
text_perturbed += f"{token} -X- -X- {label}\n"

if writing_mode == "separate":
return text, text_perturbed
return text
if test_case:
test_case_items = test_case.split()
norm_test_case_items = test_case.lower().split()
norm_original_items = original.lower().split()
temp_len = 0
for jdx, item in enumerate(norm_test_case_items):
try:
if item in norm_original_items and jdx >= norm_original_items.index(
item
):
oitem_index = norm_original_items.index(item)
j = sample.expected_results.predictions[oitem_index + temp_len]
if temp_id != j.doc_id and jdx == 0:
text += f"{j.doc_name}\n\n"
temp_id = j.doc_id
text += f"{test_case_items[jdx]} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
norm_original_items.pop(oitem_index)
temp_len += 1
else:
o_item = sample.expected_results.predictions[jdx].span.word
letters_count = len(set(item) - set(o_item))
if (
len(norm_test_case_items) == len(original.lower().split())
or letters_count < 2
):
tl = sample.expected_results.predictions[jdx]
text += f"{test_case_items[jdx]} {tl.pos_tag} {tl.chunk_tag} {tl.entity}\n"
else:
text += f"{test_case_items[jdx]} -X- -X- O\n"
except IndexError:
text += f"{test_case_items[jdx]} -X- -X- O\n"

else:
for j in sample.expected_results.predictions:
if temp_id != j.doc_id:
text += f"{j.doc_name}\n\n"
temp_id = j.doc_id
text += f"{j.span.word} {j.pos_tag} {j.chunk_tag} {j.entity}\n"

return text, temp_id
6 changes: 5 additions & 1 deletion langtest/transform/robustness.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,13 +487,17 @@ def transform(
sample.category = "robustness"
if all([label == "O" for label in sample_labels]):
sample.test_case = sample.original
continue
break

sent_tokens = sample.original.split(" ")

ent_start_pos = [1 if label[0] == "B" else 0 for label in sample_labels]
ent_idx = [i for i, value in enumerate(ent_start_pos) if value == 1]

if not ent_idx:
sample.test_case = sample.original
break

replace_idx = random.choice(ent_idx)
ent_type = sample_labels[replace_idx][2:]
replace_idxs = [replace_idx]
Expand Down
5 changes: 4 additions & 1 deletion tests/test_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ class TestNERDataset:
test_type="add_context",
expected_results=NEROutput(
predictions=[
NERPrediction(entity="PROD", span=Span(start=10, end=13, word="KFC"))
NERPrediction(entity="O", span=Span(start=10, end=13, word="I")),
NERPrediction(entity="O", span=Span(start=10, end=13, word="do")),
NERPrediction(entity="O", span=Span(start=10, end=13, word="love")),
NERPrediction(entity="PROD", span=Span(start=10, end=13, word="KFC")),
]
),
)
Expand Down