Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions langtest/transform/robustness.py
Original file line number Diff line number Diff line change
Expand Up @@ -1941,7 +1941,10 @@ def randomize_ages(text):


class AddNewLines(BaseRobustness):
"""A class for adding new lines to the input text."""

alias_name = "add_new_lines"
supported_tasks = ["text-classification", "question-answering", "summarization"]

@staticmethod
def transform(
Expand Down Expand Up @@ -2039,3 +2042,123 @@ def add_new_lines(text: str) -> Tuple[str, List[Transformation]]:
sample.category = "robustness"
perturbed_samples.append(sample)
return perturbed_samples


class AddTabs(BaseRobustness):
"""A class for adding tabs to the input text."""

alias_name = "add_tabs"
supported_tasks = ["text-classification", "question-answering", "summarization"]

@staticmethod
def transform(
sample_list: List[Sample],
prob: Optional[float] = 1.0,
count: int = 1,
max_tabs: int = 5,
) -> List[Sample]:
"""Transforms the given sample list by adding tabs to the input text.

Args:
sample_list (List[Sample]): The list of samples to transform.
prob (Optional[float]): The probability controlling the proportion of samples to be perturbed.
Defaults to 1.0, which means all samples will be transformed.
count: Number of variations to create.

Returns:
List[Sample]: The transformed list of samples with tabs added.
"""

def add_tabs(text: str) -> Tuple[str, List[Transformation]]:
"""
Inserts a random number of tab characters ('\t') after specific tokens in the input text
based on a given probability and a maximum limit on the number of tab insertions per token.

Args:
text (str): The input text to modify.
prob (float): Probability of adding tabs after a token.
max_tabs (int): Maximum number of tabs to insert after a token.

Returns:
Tuple[str, List[Transformation]]: The modified text and a list of transformations applied.
"""
transformations = []
perturbed_text = ""
prev_end = 0
offset = 0 # Track the number of extra characters added (tabs) to avoid breaking words

# Find all tokens and their positions
tokens = []
for match in re.finditer(r"\S+", text):
word = match.group()
start = match.start()
end = match.end()
tokens.append({"word": word, "start": start, "end": end})

# If there are too few tokens, no transformation is applied
if len(tokens) < 5:
return text, transformations

# Decide which tokens to transform
transformed_indices = []
for i, _ in enumerate(tokens):
if random.random() < prob:
transformed_indices.append(i)

# Build the perturbed text and record transformations
for i, token in enumerate(tokens):
# Add any intermediate spaces or punctuation
perturbed_text += text[prev_end : token["start"]]

perturbed_start = len(perturbed_text)
perturbed_word = token["word"]

# If the token is selected for transformation, add tabs after it
if i in transformed_indices:
tabs_to_insert = "\t" * random.randint(1, max_tabs)
perturbed_word_with_tabs = perturbed_word + tabs_to_insert

# Record the transformation details
transformations.append(
Transformation(
original_span=Span(
start=token["start"], end=token["end"], word=token["word"]
),
new_span=Span(
start=perturbed_start,
end=perturbed_start + len(perturbed_word_with_tabs),
word=perturbed_word_with_tabs,
),
ignore=False,
)
)

# Add the perturbed word with tabs
perturbed_text += perturbed_word_with_tabs

# Adjust the offset for future transformations (since we added tabs)
offset += len(tabs_to_insert)
else:
# Add the token without tabs if not selected for transformation
perturbed_text += perturbed_word

prev_end = token["end"] # Track the end of the current token

# Add any remaining text after the last token
perturbed_text += text[prev_end:]

return perturbed_text, transformations

perturbed_samples = []
for s in sample_list:
for i in range(count):
sample = deepcopy(s)
if isinstance(sample, str):
sample, _ = add_tabs(sample)
else:
sample.test_case, transformations = add_tabs(sample.original)
if sample.task in ("ner", "text-classification"):
sample.transformations = transformations
sample.category = "robustness"
perturbed_samples.append(sample)
return perturbed_samples