diff --git a/langtest/transform/accuracy.py b/langtest/transform/accuracy.py
index eb78de8c6..c9f4ccfc5 100644
--- a/langtest/transform/accuracy.py
+++ b/langtest/transform/accuracy.py
@@ -15,7 +15,12 @@
 )
 from langtest.utils.custom_types.helpers import default_user_prompt
 from langtest.errors import Errors
-from langtest.utils.util_metrics import calculate_f1_score, classification_report
+from langtest.utils.util_metrics import (
+    calculate_f1_score,
+    calculate_f1_score_multi_label,
+    classification_report,
+    classification_report_multi_label,
+)
 
 
 class AccuracyTestFactory(ITests):
@@ -151,6 +156,7 @@ def predict_ner(sample):
             y_true = y_true.apply(lambda x: x.split("-")[-1])
 
         elif isinstance(raw_data_copy[0], SequenceClassificationSample):
+            is_multi_label = raw_data_copy[0].expected_results.multi_label
 
             def predict_text_classification(sample):
                 prediction = model.predict(sample.original)
@@ -166,11 +172,16 @@ def predict_text_classification(sample):
             y_pred = pd.Series(raw_data_copy).apply(
                 lambda x: [y.label for y in x.actual_results.predictions]
             )
 
-            y_true = y_true.apply(lambda x: x[0])
-            y_pred = y_pred.apply(lambda x: x[0])
-            y_true = y_true.explode()
-            y_pred = y_pred.explode()
+            if is_multi_label:
+                kwargs["is_multi_label"] = is_multi_label
+
+            else:
+                y_true = y_true.apply(lambda x: x[0])
+                y_pred = y_pred.apply(lambda x: x[0])
+
+                y_true = y_true.explode()
+                y_pred = y_pred.explode()
 
         elif raw_data_copy[0].task == "question-answering":
             from ..utils.custom_types.helpers import build_qa_input, build_qa_prompt
@@ -374,7 +385,13 @@ async def run(
         """
             y_pred (List[Any]): Predicted values
         """
         progress = kwargs.get("progress_bar", False)
-        df_metrics = classification_report(y_true, y_pred, zero_division=0)
+        is_multi_label = kwargs.get("is_multi_label", False)
+        if is_multi_label:
+            df_metrics = classification_report_multi_label(
+                y_true, y_pred, zero_division=0
+            )
+        else:
+            df_metrics = classification_report(y_true, y_pred, zero_division=0)
         df_metrics.pop("macro avg")
 
         for idx, sample in enumerate(sample_list):
@@ -454,7 +471,13 @@ async def run(
         """
         progress = kwargs.get("progress_bar", False)
-        df_metrics = classification_report(y_true, y_pred, zero_division=0)
+        is_multi_label = kwargs.get("is_multi_label", False)
+        if is_multi_label:
+            df_metrics = classification_report_multi_label(
+                y_true, y_pred, zero_division=0
+            )
+        else:
+            df_metrics = classification_report(y_true, y_pred, zero_division=0)
         df_metrics.pop("macro avg")
 
         for idx, sample in enumerate(sample_list):
@@ -531,8 +554,13 @@ async def run(
         """
         progress = kwargs.get("progress_bar", False)
-
-        df_metrics = classification_report(y_true, y_pred, zero_division=0)
+        is_multi_label = kwargs.get("is_multi_label", False)
+        if is_multi_label:
+            df_metrics = classification_report_multi_label(
+                y_true, y_pred, zero_division=0
+            )
+        else:
+            df_metrics = classification_report(y_true, y_pred, zero_division=0)
         df_metrics.pop("macro avg")
 
         for idx, sample in enumerate(sample_list):
@@ -599,8 +627,14 @@ async def run(
         """
         progress = kwargs.get("progress_bar", False)
+        is_multi_label = kwargs.get("is_multi_label", False)
 
-        f1 = calculate_f1_score(y_true, y_pred, average="micro", zero_division=0)
+        if is_multi_label:
+            f1 = calculate_f1_score_multi_label(
+                y_true, y_pred, average="micro", zero_division=0
+            )
+        else:
+            f1 = calculate_f1_score(y_true, y_pred, average="micro", zero_division=0)
 
         for sample in sample_list:
             sample.actual_results = MinScoreOutput(min_score=f1)
@@ -664,7 +698,14 @@ async def run(
         """
         progress = kwargs.get("progress_bar", False)
 
-        f1 = calculate_f1_score(y_true, y_pred, average="macro", zero_division=0)
+        is_multi_label = kwargs.get("is_multi_label", False)
+
+        if is_multi_label:
+            f1 = calculate_f1_score_multi_label(
+                y_true, y_pred, average="macro", zero_division=0
+            )
+        else:
+            f1 = calculate_f1_score(y_true, y_pred, average="macro", zero_division=0)
 
         for sample in sample_list:
             sample.actual_results = MinScoreOutput(min_score=f1)
@@ -726,7 +767,14 @@ async def run(
         """
         progress = kwargs.get("progress_bar", False)
 
-        f1 = calculate_f1_score(y_true, y_pred, average="weighted", zero_division=0)
+        is_multi_label = kwargs.get("is_multi_label", False)
+
+        if is_multi_label:
+            f1 = calculate_f1_score_multi_label(
+                y_true, y_pred, average="weighted", zero_division=0
+            )
+        else:
+            f1 = calculate_f1_score(y_true, y_pred, average="weighted", zero_division=0)
 
         for sample in sample_list:
             sample.actual_results = MinScoreOutput(min_score=f1)
diff --git a/langtest/utils/custom_types/helpers.py b/langtest/utils/custom_types/helpers.py
index fa9e43f61..b7439d69d 100644
--- a/langtest/utils/custom_types/helpers.py
+++ b/langtest/utils/custom_types/helpers.py
@@ -756,8 +756,12 @@ def prepare_model_response(self, data):
 
         if data[0].task == "text-classification":
             for sample in data:
-                sample.actual_results = sample.actual_results.predictions[0]
-                sample.expected_results = sample.expected_results.predictions[0]
+                if sample.expected_results.multi_label:
+                    sample.actual_results = sample.actual_results
+                    sample.expected_results = sample.expected_results
+                else:
+                    sample.actual_results = sample.actual_results.predictions[0]
+                    sample.expected_results = sample.expected_results.predictions[0]
         elif data[0].task == "ner":
             for sample in data:
                 sample.actual_results = sample.actual_results.predictions
diff --git a/langtest/utils/custom_types/output.py b/langtest/utils/custom_types/output.py
index 6961e4b0f..0808e92bd 100644
--- a/langtest/utils/custom_types/output.py
+++ b/langtest/utils/custom_types/output.py
@@ -20,6 +20,8 @@ def to_str_list(self) -> str:
 
     def __str__(self) -> str:
         """String representation"""
+        if self.multi_label:
+            return self.to_str_list()
         labels = {elt.label: elt.score for elt in self.predictions}
         return f"SequenceClassificationOutput(predictions={labels})"
diff --git a/langtest/utils/util_metrics.py b/langtest/utils/util_metrics.py
index 4fd960a0f..eb407d34d 100644
--- a/langtest/utils/util_metrics.py
+++ b/langtest/utils/util_metrics.py
@@ -1,10 +1,13 @@
 from collections import Counter
-from typing import List, Union, Dict
+from typing import List, Set, Union, Dict
 
 from ..errors import Errors
 
 
 def classification_report(
-    y_true: List[Union[str, int]], y_pred: List[Union[str, int]], zero_division: int = 0
+    y_true: List[Union[str, int]],
+    y_pred: List[Union[str, int]],
+    zero_division: int = 0,
+    multi_label: bool = False,
 ) -> Dict[str, Dict[str, Union[float, int]]]:
     """Generate a classification report including precision, recall, f1-score, and support.
@@ -170,3 +173,222 @@ def calculate_f1_score(
     else:
         raise ValueError(Errors.E074)
     return f1_score
+
+
+def simple_multilabel_binarizer(y_true, y_pred):
+    """
+    A simple implementation of a multilabel binarizer for y_true and y_pred.
+
+    Args:
+        y_true (list of lists or sets): Actual labels for the data.
+        y_pred (list of lists or sets): Predicted labels for the data.
+
+    Returns:
+        binarized_y_true (list of lists): Binary matrix of true labels.
+        binarized_y_pred (list of lists): Binary matrix of predicted labels.
+        classes (list): List of all unique classes (labels).
+    """
+    # Ensure we collect unique classes from both y_true and y_pred
+    classes = sorted(set(label for labels in y_true + y_pred for label in labels))
+
+    # Create a binary matrix for y_true and y_pred
+    y_true_bin = [[1 if label in labels else 0 for label in classes] for labels in y_true]
+    y_pred_bin = [[1 if label in labels else 0 for label in classes] for labels in y_pred]
+
+    # Return the binarized labels and the consistent set of classes
+    return y_true_bin, y_pred_bin, classes
+
+
+def classification_report_multi_label(
+    y_true: List[Set[Union[str, int]]],
+    y_pred: List[Set[Union[str, int]]],
+    zero_division: int = 0,
+) -> Dict[str, Dict[str, Union[float, int]]]:
+    """
+    Generate a classification report for multi-label classification.
+
+    Args:
+        y_true (List[Set[Union[str, int]]]): List of sets of true labels.
+        y_pred (List[Set[Union[str, int]]]): List of sets of predicted labels.
+        zero_division (int, optional): Specifies the value to return when there is a zero division case. Defaults to 0.
+
+    Returns:
+        Dict[str, Dict[str, Union[float, int]]]: Classification report.
+    """
+    # Binarize the multi-label data
+    y_true_bin, y_pred_bin, classes = simple_multilabel_binarizer(y_true, y_pred)
+
+    # Initialize data structure for the report
+    report = {}
+    for i, class_label in enumerate(classes):
+        support = sum(row[i] for row in y_true_bin)
+        predicted_labels = sum(row[i] for row in y_pred_bin)
+        correct_predictions = sum(
+            1
+            for true_row, pred_row in zip(y_true_bin, y_pred_bin)
+            if true_row[i] == pred_row[i] == 1
+        )
+
+        # Precision, recall, and F1 score calculations
+        if predicted_labels > 0:
+            precision = correct_predictions / predicted_labels
+        else:
+            precision = zero_division
+
+        if support > 0:
+            recall = correct_predictions / support
+        else:
+            recall = zero_division
+
+        if (precision + recall) > 0:
+            f1_score = (2 * precision * recall) / (precision + recall)
+        else:
+            f1_score = zero_division
+
+        # Add stats to the report
+        report[class_label] = {
+            "precision": precision,
+            "recall": recall,
+            "f1-score": f1_score,
+            "support": support,
+        }
+
+    # Compute macro averages
+    avg_precision = sum([metrics["precision"] for metrics in report.values()]) / len(
+        report
+    )
+    avg_recall = sum([metrics["recall"] for metrics in report.values()]) / len(report)
+    avg_f1_score = sum([metrics["f1-score"] for metrics in report.values()]) / len(report)
+
+    report["macro avg"] = {
+        "precision": avg_precision,
+        "recall": avg_recall,
+        "f1-score": avg_f1_score,
+        "support": len(y_true),
+    }
+
+    return report
+
+
+def calculate_f1_score_multi_label(
+    y_true: List[Set[Union[str, int]]],
+    y_pred: List[Set[Union[str, int]]],
+    average: str = "macro",
+    zero_division: int = 0,
+) -> float:
+    """
+    Calculate the F1 score for multi-label classification using binarized labels.
+
+    Args:
+        y_true (List[Set[Union[str, int]]]): List of sets of true labels.
+        y_pred (List[Set[Union[str, int]]]): List of sets of predicted labels.
+        average (str, optional): Method to calculate F1 score, can be 'micro', 'macro', or 'weighted'. Defaults to 'macro'.
+        zero_division (int, optional): Value to return when there is a zero division case. Defaults to 0.
+
+    Returns:
+        float: Calculated F1 score for multi-label classification.
+
+    Raises:
+        AssertionError: If lengths of y_true and y_pred are not equal.
+        ValueError: If invalid averaging method is provided.
+    """
+    assert len(y_true) == len(y_pred), "Lengths of y_true and y_pred must be equal."
+
+    # Binarize the labels and get the unique class set
+    y_true_bin, y_pred_bin, classes = simple_multilabel_binarizer(y_true, y_pred)
+
+    # Number of classes should remain consistent
+    num_classes = len(classes)
+
+    if average == "micro":
+        true_positives = sum(
+            1
+            for i in range(len(y_true_bin))
+            for j in range(num_classes)
+            if y_true_bin[i][j] == y_pred_bin[i][j] == 1
+        )
+        false_positives = sum(
+            1
+            for i in range(len(y_true_bin))
+            for j in range(num_classes)
+            if y_pred_bin[i][j] == 1 and y_true_bin[i][j] == 0
+        )
+        false_negatives = sum(
+            1
+            for i in range(len(y_true_bin))
+            for j in range(num_classes)
+            if y_pred_bin[i][j] == 0 and y_true_bin[i][j] == 1
+        )
+
+        precision = (
+            true_positives / (true_positives + false_positives)
+            if (true_positives + false_positives) > 0
+            else zero_division
+        )
+        recall = (
+            true_positives / (true_positives + false_negatives)
+            if (true_positives + false_negatives) > 0
+            else zero_division
+        )
+        f1_score = (
+            2 * (precision * recall) / (precision + recall)
+            if (precision + recall) > 0
+            else zero_division
+        )
+
+    elif average in ["macro", "weighted"]:
+        f1_score = 0.0
+        total_support = sum(
+            sum(y_true_bin[i][j] for i in range(len(y_true_bin)))
+            for j in range(num_classes)
+        )
+
+        for j in range(num_classes):
+            true_positives = sum(
+                1
+                for i in range(len(y_true_bin))
+                if y_true_bin[i][j] == y_pred_bin[i][j] == 1
+            )
+            false_positives = sum(
+                1
+                for i in range(len(y_true_bin))
+                if y_pred_bin[i][j] == 1 and y_true_bin[i][j] == 0
+            )
+            false_negatives = sum(
+                1
+                for i in range(len(y_true_bin))
+                if y_pred_bin[i][j] == 0 and y_true_bin[i][j] == 1
+            )
+
+            precision = (
+                true_positives / (true_positives + false_positives)
+                if (true_positives + false_positives) > 0
+                else zero_division
+            )
+            recall = (
+                true_positives / (true_positives + false_negatives)
+                if (true_positives + false_negatives) > 0
+                else zero_division
+            )
+
+            if precision + recall > 0:
+                class_f1_score = 2 * (precision * recall) / (precision + recall)
+            else:
+                class_f1_score = 0.0
+
+            # Support for the current class (how many times it appears in y_true)
+            support = sum(y_true_bin[i][j] for i in range(len(y_true_bin)))
+
+            if average == "macro":
+                f1_score += class_f1_score / num_classes
+            elif average == "weighted":
+                # Normalize weights by dividing the support by the total number of labels
+                weight = support / total_support if total_support > 0 else 0
+                f1_score += weight * class_f1_score
+
+    else:
+        raise ValueError(
+            "Invalid averaging method. Must be 'micro', 'macro', or 'weighted'."
+        )
+
+    return min(f1_score, 1.0)  # Ensure the F1 score is capped at 1.0
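
For reference, a minimal sketch of how the helpers introduced in langtest/utils/util_metrics.py could be exercised once this patch is applied. The label sets below are illustrative only and are not taken from the library's test data.

# Illustrative usage of the new multi-label helpers (assumes the patched
# langtest package is importable).
from langtest.utils.util_metrics import (
    classification_report_multi_label,
    simple_multilabel_binarizer,
)

y_true = [{"sports", "politics"}, {"tech"}, {"sports"}]
y_pred = [{"sports"}, {"politics", "tech"}, {"sports"}]

# The binarizer turns each label set into a 0/1 row over the union of classes.
y_true_bin, y_pred_bin, classes = simple_multilabel_binarizer(y_true, y_pred)
print(classes)     # ['politics', 'sports', 'tech']
print(y_true_bin)  # [[1, 1, 0], [0, 0, 1], [0, 1, 0]]

# Per-class precision/recall/F1 plus a "macro avg" entry, mirroring the
# structure of the single-label classification_report consumed by the
# accuracy tests.
report = classification_report_multi_label(y_true, y_pred, zero_division=0)
print(report["sports"])  # {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 2}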
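
Similarly, a small sketch of the three averaging modes accepted by calculate_f1_score_multi_label, using the same toy data; the values in the comments are worked out by hand for this example and are not library output.

# Illustrative only: how the averaging modes differ on the toy data above.
from langtest.utils.util_metrics import calculate_f1_score_multi_label

y_true = [{"sports", "politics"}, {"tech"}, {"sports"}]
y_pred = [{"sports"}, {"politics", "tech"}, {"sports"}]

# micro: pools true/false positives over every (sample, class) cell -> 0.75 here.
micro = calculate_f1_score_multi_label(y_true, y_pred, average="micro", zero_division=0)

# macro: unweighted mean of per-class F1 scores -> (0 + 1 + 1) / 3 ~ 0.67 here.
macro = calculate_f1_score_multi_label(y_true, y_pred, average="macro", zero_division=0)

# weighted: per-class F1 weighted by class support -> 0.75 here.
weighted = calculate_f1_score_multi_label(
    y_true, y_pred, average="weighted", zero_division=0
)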