72 changes: 60 additions & 12 deletions langtest/transform/accuracy.py
@@ -15,7 +15,12 @@
)
from langtest.utils.custom_types.helpers import default_user_prompt
from langtest.errors import Errors
from langtest.utils.util_metrics import calculate_f1_score, classification_report
from langtest.utils.util_metrics import (
calculate_f1_score,
calculate_f1_score_multi_label,
classification_report,
classification_report_multi_label,
)


class AccuracyTestFactory(ITests):
@@ -151,6 +156,7 @@ def predict_ner(sample):
y_true = y_true.apply(lambda x: x.split("-")[-1])

elif isinstance(raw_data_copy[0], SequenceClassificationSample):
is_multi_label = raw_data_copy[0].expected_results.multi_label

def predict_text_classification(sample):
prediction = model.predict(sample.original)
@@ -166,11 +172,16 @@ def predict_text_classification(sample):
y_pred = pd.Series(raw_data_copy).apply(
lambda x: [y.label for y in x.actual_results.predictions]
)
y_true = y_true.apply(lambda x: x[0])
y_pred = y_pred.apply(lambda x: x[0])

y_true = y_true.explode()
y_pred = y_pred.explode()
if is_multi_label:
    kwargs["is_multi_label"] = is_multi_label

else:
y_true = y_true.apply(lambda x: x[0])
y_pred = y_pred.apply(lambda x: x[0])

y_true = y_true.explode()
y_pred = y_pred.explode()
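# Editor's note (not part of the PR): for multi-label samples the per-sample
# label lists are kept intact and `is_multi_label` is forwarded to the metric
# runners via kwargs; single-label samples keep only the top prediction, and
# .explode() then flattens the one-element lists into scalar rows, e.g.
# pd.Series([["a"], ["b"]]).explode() -> rows "a", "b".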

elif raw_data_copy[0].task == "question-answering":
from ..utils.custom_types.helpers import build_qa_input, build_qa_prompt
@@ -374,7 +385,13 @@ async def run(
y_pred (List[Any]): Predicted values
"""
progress = kwargs.get("progress_bar", False)
df_metrics = classification_report(y_true, y_pred, zero_division=0)
is_multi_label = kwargs.get("is_multi_label", False)
if is_multi_label:
df_metrics = classification_report_multi_label(
y_true, y_pred, zero_division=0
)
else:
df_metrics = classification_report(y_true, y_pred, zero_division=0)
df_metrics.pop("macro avg")

for idx, sample in enumerate(sample_list):
@@ -454,7 +471,13 @@ async def run(
"""
progress = kwargs.get("progress_bar", False)

df_metrics = classification_report(y_true, y_pred, zero_division=0)
is_multi_label = kwargs.get("is_multi_label", False)
if is_multi_label:
df_metrics = classification_report_multi_label(
y_true, y_pred, zero_division=0
)
else:
df_metrics = classification_report(y_true, y_pred, zero_division=0)
df_metrics.pop("macro avg")

for idx, sample in enumerate(sample_list):
@@ -531,8 +554,13 @@ async def run(

"""
progress = kwargs.get("progress_bar", False)

df_metrics = classification_report(y_true, y_pred, zero_division=0)
is_multi_label = kwargs.get("is_multi_label", False)
if is_multi_label:
df_metrics = classification_report_multi_label(
y_true, y_pred, zero_division=0
)
else:
df_metrics = classification_report(y_true, y_pred, zero_division=0)
df_metrics.pop("macro avg")

for idx, sample in enumerate(sample_list):
@@ -599,8 +627,14 @@ async def run(

"""
progress = kwargs.get("progress_bar", False)
is_multi_label = kwargs.get("is_multi_label", False)

f1 = calculate_f1_score(y_true, y_pred, average="micro", zero_division=0)
if is_multi_label:
f1 = calculate_f1_score_multi_label(
y_true, y_pred, average="micro", zero_division=0
)
else:
f1 = calculate_f1_score(y_true, y_pred, average="micro", zero_division=0)

for sample in sample_list:
sample.actual_results = MinScoreOutput(min_score=f1)
@@ -664,7 +698,14 @@ async def run(
"""
progress = kwargs.get("progress_bar", False)

f1 = calculate_f1_score(y_true, y_pred, average="macro", zero_division=0)
is_multi_label = kwargs.get("is_multi_label", False)

if is_multi_label:
f1 = calculate_f1_score_multi_label(
y_true, y_pred, average="macro", zero_division=0
)
else:
f1 = calculate_f1_score(y_true, y_pred, average="macro", zero_division=0)

for sample in sample_list:
sample.actual_results = MinScoreOutput(min_score=f1)
@@ -726,7 +767,14 @@ async def run(

"""
progress = kwargs.get("progress_bar", False)
f1 = calculate_f1_score(y_true, y_pred, average="weighted", zero_division=0)
is_multi_label = kwargs.get("is_multi_label", False)

if is_multi_label:
f1 = calculate_f1_score_multi_label(
y_true, y_pred, average="weighted", zero_division=0
)
else:
f1 = calculate_f1_score(y_true, y_pred, average="weighted", zero_division=0)

for sample in sample_list:
sample.actual_results = MinScoreOutput(min_score=f1)
8 changes: 6 additions & 2 deletions langtest/utils/custom_types/helpers.py
@@ -756,8 +756,12 @@ def prepare_model_response(self, data):

if data[0].task == "text-classification":
for sample in data:
sample.actual_results = sample.actual_results.predictions[0]
sample.expected_results = sample.expected_results.predictions[0]
if not sample.expected_results.multi_label:
    sample.actual_results = sample.actual_results.predictions[0]
    sample.expected_results = sample.expected_results.predictions[0]
elif data[0].task == "ner":
for sample in data:
sample.actual_results = sample.actual_results.predictions
2 changes: 2 additions & 0 deletions langtest/utils/custom_types/output.py
@@ -20,6 +20,8 @@ def to_str_list(self) -> str:

def __str__(self) -> str:
"""String representation"""
if self.multi_label:
return self.to_str_list()
labels = {elt.label: elt.score for elt in self.predictions}
return f"SequenceClassificationOutput(predictions={labels})"

226 changes: 224 additions & 2 deletions langtest/utils/util_metrics.py
@@ -1,10 +1,13 @@
from collections import Counter
from typing import List, Union, Dict
from typing import List, Set, Union, Dict
from ..errors import Errors


def classification_report(
y_true: List[Union[str, int]], y_pred: List[Union[str, int]], zero_division: int = 0
y_true: List[Union[str, int]],
y_pred: List[Union[str, int]],
zero_division: int = 0,
multi_label: bool = False,
) -> Dict[str, Dict[str, Union[float, int]]]:
"""Generate a classification report including precision, recall, f1-score, and support.

@@ -170,3 +173,222 @@ def calculate_f1_score(
else:
raise ValueError(Errors.E074)
return f1_score


def simple_multilabel_binarizer(y_true, y_pred):
"""
A simple implementation of a multilabel binarizer for y_true and y_pred.

Args:
y_true (list of lists or sets): Actual labels for the data.
y_pred (list of lists or sets): Predicted labels for the data.

Returns:
binarized_y_true (list of lists): Binary matrix of true labels.
binarized_y_pred (list of lists): Binary matrix of predicted labels.
classes (list): List of all unique classes (labels).
"""
# Ensure we collect unique classes from both y_true and y_pred
classes = sorted(set(label for labels in y_true + y_pred for label in labels))

# Create a binary matrix for y_true and y_pred
y_true_bin = [[1 if label in labels else 0 for label in classes] for labels in y_true]
y_pred_bin = [[1 if label in labels else 0 for label in classes] for labels in y_pred]

# Return the binarized labels and the consistent set of classes
return y_true_bin, y_pred_bin, classes
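A quick sanity check of the binarizer (editor's sketch, not part of the PR; the label sets below are made up):

y_true = [{"sports", "news"}, {"tech"}]
y_pred = [{"sports"}, {"tech", "news"}]

y_true_bin, y_pred_bin, classes = simple_multilabel_binarizer(y_true, y_pred)
# classes    == ["news", "sports", "tech"]   (sorted union of all labels seen)
# y_true_bin == [[1, 1, 0], [0, 0, 1]]
# y_pred_bin == [[0, 1, 0], [1, 0, 1]]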


def classification_report_multi_label(
y_true: List[Set[Union[str, int]]],
y_pred: List[Set[Union[str, int]]],
zero_division: int = 0,
) -> Dict[str, Dict[str, Union[float, int]]]:
"""
Generate a classification report for multi-label classification.

Args:
y_true (List[Set[Union[str, int]]]): List of sets of true labels.
y_pred (List[Set[Union[str, int]]]): List of sets of predicted labels.
zero_division (int, optional): Specifies the value to return when there is a zero division case. Defaults to 0.

Returns:
Dict[str, Dict[str, Union[float, int]]]: Classification report.
"""
# Binarize the multi-label data
y_true_bin, y_pred_bin, classes = simple_multilabel_binarizer(y_true, y_pred)

# Initialize data structure for the report
report = {}
for i, class_label in enumerate(classes):
support = sum(row[i] for row in y_true_bin)
predicted_labels = sum(row[i] for row in y_pred_bin)
correct_predictions = sum(
1
for true_row, pred_row in zip(y_true_bin, y_pred_bin)
if true_row[i] == pred_row[i] == 1
)

# Precision, recall, and F1 score calculations
if predicted_labels > 0:
precision = correct_predictions / predicted_labels
else:
precision = zero_division

if support > 0:
recall = correct_predictions / support
else:
recall = zero_division

if (precision + recall) > 0:
f1_score = (2 * precision * recall) / (precision + recall)
else:
f1_score = zero_division

# Add stats to the report
report[class_label] = {
"precision": precision,
"recall": recall,
"f1-score": f1_score,
"support": support,
}

# Compute macro averages
avg_precision = sum([metrics["precision"] for metrics in report.values()]) / len(
report
)
avg_recall = sum([metrics["recall"] for metrics in report.values()]) / len(report)
avg_f1_score = sum([metrics["f1-score"] for metrics in report.values()]) / len(report)

report["macro avg"] = {
"precision": avg_precision,
"recall": avg_recall,
"f1-score": avg_f1_score,
"support": len(y_true),
}

return report
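Continuing the made-up labels from the binarizer sketch above, the report contains one entry per class plus a macro average; values below are hand-computed from the definitions above, not output from the PR's tests:

# y_true = [{"sports", "news"}, {"tech"}]; y_pred = [{"sports"}, {"tech", "news"}]
report = classification_report_multi_label(y_true, y_pred)
# report["sports"] == {"precision": 1.0, "recall": 1.0, "f1-score": 1.0, "support": 1}
# report["news"]   == {"precision": 0.0, "recall": 0.0, "f1-score": 0, "support": 1}
# report["macro avg"]["f1-score"] ~= 0.667; report["macro avg"]["support"] == 2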


def calculate_f1_score_multi_label(
y_true: List[Set[Union[str, int]]],
y_pred: List[Set[Union[str, int]]],
average: str = "macro",
zero_division: int = 0,
) -> float:
"""
Calculate the F1 score for multi-label classification using binarized labels.

Args:
y_true (List[Set[Union[str, int]]]): List of sets of true labels.
y_pred (List[Set[Union[str, int]]]): List of sets of predicted labels.
average (str, optional): Method to calculate F1 score, can be 'micro', 'macro', or 'weighted'. Defaults to 'macro'.
zero_division (int, optional): Value to return when there is a zero division case. Defaults to 0.

Returns:
float: Calculated F1 score for multi-label classification.

Raises:
AssertionError: If lengths of y_true and y_pred are not equal.
ValueError: If invalid averaging method is provided.
"""
assert len(y_true) == len(y_pred), "Lengths of y_true and y_pred must be equal."

# Binarize the labels and get the unique class set
y_true_bin, y_pred_bin, classes = simple_multilabel_binarizer(y_true, y_pred)

# Number of classes should remain consistent
num_classes = len(classes)

if average == "micro":
true_positives = sum(
1
for i in range(len(y_true_bin))
for j in range(num_classes)
if y_true_bin[i][j] == y_pred_bin[i][j] == 1
)
false_positives = sum(
1
for i in range(len(y_true_bin))
for j in range(num_classes)
if y_pred_bin[i][j] == 1 and y_true_bin[i][j] == 0
)
false_negatives = sum(
1
for i in range(len(y_true_bin))
for j in range(num_classes)
if y_pred_bin[i][j] == 0 and y_true_bin[i][j] == 1
)

precision = (
true_positives / (true_positives + false_positives)
if (true_positives + false_positives) > 0
else zero_division
)
recall = (
true_positives / (true_positives + false_negatives)
if (true_positives + false_negatives) > 0
else zero_division
)
f1_score = (
2 * (precision * recall) / (precision + recall)
if (precision + recall) > 0
else zero_division
)

elif average in ["macro", "weighted"]:
f1_score = 0.0
total_support = sum(
sum(y_true_bin[i][j] for i in range(len(y_true_bin)))
for j in range(num_classes)
)

for j in range(num_classes):
true_positives = sum(
1
for i in range(len(y_true_bin))
if y_true_bin[i][j] == y_pred_bin[i][j] == 1
)
false_positives = sum(
1
for i in range(len(y_true_bin))
if y_pred_bin[i][j] == 1 and y_true_bin[i][j] == 0
)
false_negatives = sum(
1
for i in range(len(y_true_bin))
if y_pred_bin[i][j] == 0 and y_true_bin[i][j] == 1
)

precision = (
true_positives / (true_positives + false_positives)
if (true_positives + false_positives) > 0
else zero_division
)
recall = (
true_positives / (true_positives + false_negatives)
if (true_positives + false_negatives) > 0
else zero_division
)

if precision + recall > 0:
class_f1_score = 2 * (precision * recall) / (precision + recall)
else:
class_f1_score = 0.0

# Support for the current class (how many times it appears in y_true)
support = sum(y_true_bin[i][j] for i in range(len(y_true_bin)))

if average == "macro":
f1_score += class_f1_score / num_classes
elif average == "weighted":
# Normalize weights by dividing the support by the total number of labels
weight = support / total_support if total_support > 0 else 0
f1_score += weight * class_f1_score

else:
raise ValueError(
"Invalid averaging method. Must be 'micro', 'macro', or 'weighted'."
)

return min(f1_score, 1.0) # Ensure the F1 score is capped at 1.0
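For the same toy labels, all three averaging modes happen to coincide (over the binarized matrix: TP=2, FP=1, FN=1), which makes an easy cross-check; again an editor's sketch, not output from the PR:

# y_true = [{"sports", "news"}, {"tech"}]; y_pred = [{"sports"}, {"tech", "news"}]
calculate_f1_score_multi_label(y_true, y_pred, average="micro")     # ~= 0.667
calculate_f1_score_multi_label(y_true, y_pred, average="macro")     # ~= 0.667 (mean of per-class F1: 0, 1, 1)
calculate_f1_score_multi_label(y_true, y_pred, average="weighted")  # ~= 0.667 (each class has support 1)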