20 changes: 17 additions & 3 deletions langtest/augmentation/augmenter.py
@@ -8,10 +8,15 @@
 from langtest.transform import TestFactory
 from langtest.tasks.task import TaskManager
 from langtest.utils.custom_types.sample import Sample
+from langtest.logger import logger
 
 
 class DataAugmenter:
-    def __init__(self, task: Union[str, TaskManager], config: Union[str, dict]) -> None:
+    def __init__(
+        self,
+        task: Union[str, TaskManager],
+        config: Union[str, dict],
+    ) -> None:
         """
         Initialize the DataAugmenter.
 
@@ -241,11 +246,20 @@ def prepare_hash_map(
 
         return hashmap
 
-    def save(self, file_path: str):
+    def save(self, file_path: str, for_gen_ai: bool = False) -> None:
         """
         Save the augmented data.
         """
-        self.__datafactory.export(data=self.__augmented_data, output_path=file_path)
+        try:
+            # .json output for NER is reserved for Gen AI Lab exports,
+            # so it is allowed only when for_gen_ai is True
+            if not for_gen_ai and self.__task.task_name == "ner":
+                if file_path.endswith(".json"):
+                    raise ValueError(
+                        ".json export for NER is only supported with for_gen_ai=True"
+                    )
+
+            self.__datafactory.export(data=self.__augmented_data, output_path=file_path)
+        except Exception as e:
+            logger.error(f"Error in saving the augmented data: {e}")
 
     def __or__(self, other: Iterable):
         results = self.augment(other)
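Taken together, these changes gate .json export behind the new flag. A minimal usage sketch follows; the task name and config contents are illustrative assumptions, not values taken from this PR:

from langtest.augmentation.augmenter import DataAugmenter

# Hypothetical config; the exact test schema is not part of this PR.
config = {
    "tests": {
        "defaults": {"min_pass_rate": 0.65},
        "robustness": {"uppercase": {"max_proportion": 0.1}},
    }
}

augmenter = DataAugmenter(task="ner", config=config)
# ... augment a dataset first, e.g. augmenter.augment(training_data) ...

augmenter.save("augmented_train.conll")                  # CoNLL export, as before
augmenter.save("augmented_train.json", for_gen_ai=True)  # new: Gen AI Lab JSON
# Without the flag, a .json path for an NER task is rejected; note that the
# error is caught inside save() and logged rather than raised to the caller.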
52 changes: 44 additions & 8 deletions langtest/datahandler/datasource.py
@@ -26,6 +26,7 @@
 from ..errors import Warnings, Errors
 import glob
 from pkg_resources import resource_filename
+from langtest.logger import logger
 
 COLUMN_MAPPER = {
     "text-classification": {
@@ -551,14 +552,49 @@ def export_data(self, data: List[NERSample], output_path: str):
             output_path (str):
                 path to save the data to
         """
-        otext = ""
-        temp_id = None
-        for i in data:
-            text, temp_id = Formatter.process(i, output_format="conll", temp_id=temp_id)
-            otext += text + "\n"
-
-        with open(output_path, "wb") as fwriter:
-            fwriter.write(bytes(otext, encoding="utf-8"))
+        if output_path.endswith(".conll"):
+            otext = ""
+            temp_id = None
+            for i in data:
+                text, temp_id = Formatter.process(
+                    i, output_format="conll", temp_id=temp_id
+                )
+                otext += text + "\n"
+
+            with open(output_path, "wb") as fwriter:
+                fwriter.write(bytes(otext, encoding="utf-8"))
+
+        elif output_path.endswith(".json"):
+            import json
+            from .utils import process_document
+
+            logger.warning("Only for Gen AI Lab use")
+            logger.info("Converting NER samples to JSON format")
+
+            otext_list = []
+            temp_id = None
+            for i in data:
+                otext, temp_id = Formatter.process(
+                    i, output_format="json", temp_id=temp_id
+                )
+                processed_text = process_document(otext)
+                # attach test metadata to each exported task
+                temp_dict = processed_text["data"]
+                temp_dict["test_type"] = i.test_type or "null"
+                temp_dict["category"] = i.category or "null"
+
+                processed_text["data"] = temp_dict
+                otext_list.append(processed_text)
+
+            with open(output_path, "w") as fwriter:
+                json.dump(otext_list, fwriter)
 
     def __token_validation(self, tokens: str) -> (bool, List[List[str]]):  # type: ignore
         """Validates the tokens in a sentence.
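For review, the new dispatch condenses to the sketch below; export_ner_sketch and conll_blocks are illustrative names, with conll_blocks standing in for the per-sample CoNLL strings that Formatter.process returns:

import json

from langtest.datahandler.utils import process_document


def export_ner_sketch(conll_blocks, output_path):
    # .conll: concatenate the per-sample blocks into one UTF-8 text file.
    if output_path.endswith(".conll"):
        with open(output_path, "wb") as fwriter:
            fwriter.write("\n".join(conll_blocks).encode("utf-8"))
    # .json: convert each block into a Gen AI Lab style annotation task.
    elif output_path.endswith(".json"):
        tasks = [process_document(block) for block in conll_blocks]
        with open(output_path, "w") as fwriter:
            json.dump(tasks, fwriter)

One behavior worth noting in the patch: paths with any other extension fall through both branches silently, so nothing is written.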
7 changes: 7 additions & 0 deletions langtest/datahandler/format.py
@@ -235,6 +235,13 @@ def to_conll(sample: NERSample, temp_id: int = None) -> Union[str, Tuple[str, str]]:
 
         return text, temp_id
 
+    @staticmethod
+    def to_json(sample: NERSample, temp_id: int = None) -> Union[str, Tuple[str, str]]:
+        """Converts a NERSample to CoNLL text for JSON post-processing.
+
+        Delegates to ``to_conll``; the resulting CoNLL text is turned into a
+        JSON task downstream by ``langtest.datahandler.utils.process_document``.
+        """
+        text, temp_id = NEROutputFormatter.to_conll(sample, temp_id)
+        return text, temp_id

 class QAFormatter(BaseFormatter):
     def to_jsonl(sample: QASample, *args, **kwargs):
116 changes: 116 additions & 0 deletions langtest/datahandler/utils.py
@@ -0,0 +1,116 @@
+from datetime import datetime
+
+
+def get_results(tokens, labels, text):
+    current_entity = None
+    current_span = []
+    results = []
+    char_pos = 0  # tracks the character position in the text
+
+    for i, (token, label) in enumerate(zip(tokens, labels)):
+        token_start = char_pos
+        token_end = token_start + len(token)
+        if label.startswith("B-"):
+            if current_entity:
+                results.append(
+                    {
+                        "value": {
+                            "start": current_span[0],
+                            "end": current_span[-1],
+                            "text": text[current_span[0] : current_span[-1]],
+                            "labels": [current_entity],
+                            "confidence": 1,
+                        },
+                        "from_name": "label",
+                        "to_name": "text",
+                        "type": "labels",
+                    }
+                )
+            current_entity = label[2:]
+            current_span = [token_start, token_end]
+        elif label.startswith("I-") and current_entity:
+            current_span[-1] = token_end
+        elif label == "O" and current_entity:
+            results.append(
+                {
+                    "value": {
+                        "start": current_span[0],
+                        "end": current_span[-1],
+                        "text": text[current_span[0] : current_span[-1]],
+                        "labels": [current_entity],
+                        "confidence": 1,
+                    },
+                    "from_name": "label",
+                    "to_name": "text",
+                    "type": "labels",
+                }
+            )
+            current_entity = None
+            current_span = []
+
+        # move to the next character position (account for the space between
+        # tokens; punctuation is attached to the previous token, no space)
+        char_pos = (
+            token_end + 1
+            if i + 1 < len(tokens) and tokens[i + 1] not in [".", ",", "!", "?"]
+            else token_end
+        )
+
+    if current_entity:
+        results.append(
+            {
+                "value": {
+                    "start": current_span[0],
+                    "end": current_span[-1],
+                    "text": text[current_span[0] : current_span[-1]],
+                    "labels": [current_entity],
+                    "confidence": 1,
+                },
+                "from_name": "label",
+                "to_name": "text",
+                "type": "labels",
+            }
+        )
+    return results
+
+
+def process_document(doc):
+    tokens = []
+    labels = []
+
+    # strip the -DOCSTART- marker; its remaining line is skipped below
+    doc = doc.replace("-DOCSTART-", "")
+
+    for line in doc.strip().split("\n"):
+        if line.strip():
+            parts = line.strip().split()
+            if len(parts) == 4:
+                token, _, _, label = parts
+                tokens.append(token)
+                labels.append(label)
+
+    # rebuild the raw text, attaching punctuation to the preceding token
+    text = ""
+    for token in tokens:
+        if token in {".", ",", "!", "?"}:
+            text = text.rstrip() + token + " "
+        else:
+            text += token + " "
+
+    text = text.rstrip()
+
+    results = get_results(tokens, labels, text)
+    now = datetime.utcnow()
+    current_date = now.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
+    json_output = {
+        "created_ago": current_date,
+        "result": results,
+        "honeypot": True,
+        "lead_time": 10,
+        "confidence_range": [0, 1],
+        "submitted_at": current_date,
+        "updated_at": current_date,
+        "predictions": [],
+        "created_at": current_date,
+        "data": {"text": text},
+    }
+
+    return json_output
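Since utils.py is self-contained, the conversion can be sanity-checked directly. A minimal sketch with a toy CoNLL document (the four columns are token, POS, chunk, and NER tag, the layout the parser expects):

from langtest.datahandler.utils import process_document

conll_doc = (
    "-DOCSTART- -X- -X- O\n"
    "\n"
    "John NNP B-NP B-PER\n"
    "lives VBZ B-VP O\n"
    "in IN B-PP O\n"
    "Paris NNP B-NP B-LOC\n"
    ". . O O\n"
)

task = process_document(conll_doc)
print(task["data"]["text"])                          # John lives in Paris.
print([r["value"]["text"] for r in task["result"]])  # ['John', 'Paris']
print(task["result"][0]["value"]["labels"])          # ['PER']

The character offsets assume single spaces between tokens with punctuation attached to the preceding word, which is exactly how process_document rebuilds the text, so the spans from get_results stay aligned with it.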