Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion langtest/modelhandler/llm_modelhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def load_model(cls, hub: str, path: str, *args, **kwargs) -> "PretrainedModelFor
try:
cls._update_model_parameters(hub, filtered_kwargs)
if path in (
"gpt-4o",
"gpt-4o-mini",
"gpt-4",
"gpt-3.5-turbo",
"gpt-4-1106-preview",
Expand Down
101 changes: 101 additions & 0 deletions langtest/transform/robustness.py
Original file line number Diff line number Diff line change
Expand Up @@ -1938,3 +1938,104 @@ def randomize_ages(text):
perturbed_samples.append(s)

return perturbed_samples


class AddNewLines(BaseRobustness):
    """Robustness perturbation that appends newline characters after tokens."""

    alias_name = "add_new_lines"

    @staticmethod
    def transform(
        sample_list: List[Sample],
        prob: Optional[float] = 1.0,
        count: int = 1,
        max_lines: int = 3,
    ) -> List[Sample]:
        """Transforms the given sample list by adding new lines to the input text.

        Texts containing fewer than five whitespace-separated tokens are
        returned unchanged.

        Args:
            sample_list (List[Sample]): The list of samples (or plain strings)
                to transform.
            prob (Optional[float]): Chance for each token to become a candidate,
                and also the fraction of those candidates that is kept.
                Defaults to 1.0; ``None`` is treated as 1.0.
            count: Number of variations to create per sample.
            max_lines: Maximum number of newline characters appended after a
                selected token. Values below 1 are treated as 1.

        Returns:
            List[Sample]: The transformed list of samples with new lines added.
        """
        # `prob` is annotated Optional, so fall back to the signature default
        # instead of failing on the `random.random() < None` comparison.
        token_prob = 1.0 if prob is None else prob
        # random.randint(1, n) raises for n < 1; at least one newline is added.
        line_cap = max(1, max_lines)

        def add_new_lines(text: str) -> Tuple[str, List[Transformation]]:
            transformations: List[Transformation] = []

            # Every run of non-whitespace characters is a token.
            tokens = list(re.finditer(r"\S+", text))
            if len(tokens) < 5:
                return text, transformations

            # First pass: each token independently becomes a candidate.
            candidates = [
                index
                for index in range(len(tokens))
                if random.random() < token_prob
            ]
            if not candidates:
                # Nothing was picked; returning here also avoids dividing by
                # zero in the sample-size computation below.
                return text, transformations

            # Second pass: keep a random subset of the candidates. The size is
            # clamped to the population so random.sample cannot raise when
            # `prob` is greater than 1.
            sample_size = min(
                2 * len(text) // len(candidates),
                int(len(candidates) * token_prob),
                len(candidates),
            )
            # A set gives O(1) membership checks inside the token loop.
            selected = set(random.sample(candidates, sample_size))

            # Rebuild the text piece by piece while tracking the output length,
            # so each new span is expressed in perturbed-text coordinates.
            pieces: List[str] = []
            out_len = 0
            prev_end = 0
            for index, match in enumerate(tokens):
                # Preserve the original whitespace between tokens verbatim.
                gap = text[prev_end : match.start()]
                pieces.append(gap)
                out_len += len(gap)

                word = match.group()
                perturbed_word = word
                if index in selected:
                    perturbed_word = word + "\n" * random.randint(1, line_cap)
                    transformations.append(
                        Transformation(
                            original_span=Span(
                                start=match.start(), end=match.end(), word=word
                            ),
                            new_span=Span(
                                start=out_len,
                                end=out_len + len(perturbed_word),
                                word=perturbed_word,
                            ),
                            ignore=False,
                        )
                    )

                pieces.append(perturbed_word)
                out_len += len(perturbed_word)
                prev_end = match.end()

            # Add any remaining text after the last token.
            pieces.append(text[prev_end:])

            return "".join(pieces), transformations

        perturbed_samples = []
        for s in sample_list:
            for _ in range(count):
                sample = deepcopy(s)
                if isinstance(sample, str):
                    # Plain strings carry no metadata; only the text is kept.
                    sample, _unused = add_new_lines(sample)
                else:
                    sample.test_case, transformations = add_new_lines(sample.original)
                    # Span bookkeeping is only recorded for these two tasks.
                    if sample.task in ("ner", "text-classification"):
                        sample.transformations = transformations
                    sample.category = "robustness"
                perturbed_samples.append(sample)
        return perturbed_samples