From 0c7c9b0c5490807eebfe2b74cc592e464923a15a Mon Sep 17 00:00:00 2001 From: Kalyan Chakravarthy Date: Mon, 16 Sep 2024 22:36:24 +0530 Subject: [PATCH 1/2] Now handles the multi-label in accuracy tests. --- langtest/transform/accuracy.py | 29 +++++-- langtest/utils/custom_types/helpers.py | 8 +- langtest/utils/custom_types/output.py | 2 + langtest/utils/util_metrics.py | 105 ++++++++++++++++++++++++- 4 files changed, 133 insertions(+), 11 deletions(-) diff --git a/langtest/transform/accuracy.py b/langtest/transform/accuracy.py index eb78de8c6..507d42408 100644 --- a/langtest/transform/accuracy.py +++ b/langtest/transform/accuracy.py @@ -15,7 +15,11 @@ ) from langtest.utils.custom_types.helpers import default_user_prompt from langtest.errors import Errors -from langtest.utils.util_metrics import calculate_f1_score, classification_report +from langtest.utils.util_metrics import ( + calculate_f1_score, + classification_report, + classification_report_multi_label, +) class AccuracyTestFactory(ITests): @@ -151,6 +155,7 @@ def predict_ner(sample): y_true = y_true.apply(lambda x: x.split("-")[-1]) elif isinstance(raw_data_copy[0], SequenceClassificationSample): + is_mutli_label = raw_data_copy[0].expected_results.multi_label def predict_text_classification(sample): prediction = model.predict(sample.original) @@ -166,11 +171,16 @@ def predict_text_classification(sample): y_pred = pd.Series(raw_data_copy).apply( lambda x: [y.label for y in x.actual_results.predictions] ) - y_true = y_true.apply(lambda x: x[0]) - y_pred = y_pred.apply(lambda x: x[0]) - y_true = y_true.explode() - y_pred = y_pred.explode() + if is_mutli_label: + kwargs["is_multi_label"] = is_mutli_label + + else: + y_true = y_true.apply(lambda x: x[0]) + y_pred = y_pred.apply(lambda x: x[0]) + + y_true = y_true.explode() + y_pred = y_pred.explode() elif raw_data_copy[0].task == "question-answering": from ..utils.custom_types.helpers import build_qa_input, build_qa_prompt @@ -531,8 +541,13 @@ async def run( """ progress = kwargs.get("progress_bar", False) - - df_metrics = classification_report(y_true, y_pred, zero_division=0) + is_multi_label = kwargs.get("is_multi_label", False) + if is_multi_label: + df_metrics = classification_report_multi_label( + y_true, y_pred, zero_division=0 + ) + else: + df_metrics = classification_report(y_true, y_pred, zero_division=0) df_metrics.pop("macro avg") for idx, sample in enumerate(sample_list): diff --git a/langtest/utils/custom_types/helpers.py b/langtest/utils/custom_types/helpers.py index fa9e43f61..b7439d69d 100644 --- a/langtest/utils/custom_types/helpers.py +++ b/langtest/utils/custom_types/helpers.py @@ -756,8 +756,12 @@ def prepare_model_response(self, data): if data[0].task == "text-classification": for sample in data: - sample.actual_results = sample.actual_results.predictions[0] - sample.expected_results = sample.expected_results.predictions[0] + if sample.expected_results.multi_label: + sample.actual_results = sample.actual_results + sample.expected_results = sample.expected_results + else: + sample.actual_results = sample.actual_results.predictions[0] + sample.expected_results = sample.expected_results.predictions[0] elif data[0].task == "ner": for sample in data: sample.actual_results = sample.actual_results.predictions diff --git a/langtest/utils/custom_types/output.py b/langtest/utils/custom_types/output.py index 6961e4b0f..0808e92bd 100644 --- a/langtest/utils/custom_types/output.py +++ b/langtest/utils/custom_types/output.py @@ -20,6 +20,8 @@ def to_str_list(self) -> str: def __str__(self) -> str: """String representation""" + if self.multi_label: + return self.to_str_list() labels = {elt.label: elt.score for elt in self.predictions} return f"SequenceClassificationOutput(predictions={labels})" diff --git a/langtest/utils/util_metrics.py b/langtest/utils/util_metrics.py index 4fd960a0f..76116094e 100644 --- a/langtest/utils/util_metrics.py +++ b/langtest/utils/util_metrics.py @@ -1,10 +1,13 @@ from collections import Counter -from typing import List, Union, Dict +from typing import List, Set, Union, Dict from ..errors import Errors def classification_report( - y_true: List[Union[str, int]], y_pred: List[Union[str, int]], zero_division: int = 0 + y_true: List[Union[str, int]], + y_pred: List[Union[str, int]], + zero_division: int = 0, + multi_label: bool = False, ) -> Dict[str, Dict[str, Union[float, int]]]: """Generate a classification report including precision, recall, f1-score, and support. @@ -170,3 +173,101 @@ def calculate_f1_score( else: raise ValueError(Errors.E074) return f1_score + + +def simple_multilabel_binarizer(y_true, y_pred): + """ + A simple implementation of a multilabel binarizer for y_true and y_pred. + + Args: + y_true (list of lists or sets): Actual labels for the data. + y_pred (list of lists or sets): Predicted labels for the data. + + Returns: + binarized_y_true (list of lists): Binary matrix of true labels. + binarized_y_pred (list of lists): Binary matrix of predicted labels. + classes (list): List of all unique classes (labels). + """ + # Get all unique labels (classes) from both y_true and y_pred + classes = sorted(set(label for labels in y_true + y_pred for label in labels)) + + # Binarize y_true and y_pred based on the classes + binarized_y_true = [ + [1 if label in labels else 0 for label in classes] for labels in y_true + ] + binarized_y_pred = [ + [1 if label in labels else 0 for label in classes] for labels in y_pred + ] + + return binarized_y_true, binarized_y_pred, classes + + +def classification_report_multi_label( + y_true: List[Set[Union[str, int]]], + y_pred: List[Set[Union[str, int]]], + zero_division: int = 0, +) -> Dict[str, Dict[str, Union[float, int]]]: + """ + Generate a classification report for multi-label classification. + + Args: + y_true (List[Set[Union[str, int]]]): List of sets of true labels. + y_pred (List[Set[Union[str, int]]]): List of sets of predicted labels. + zero_division (int, optional): Specifies the value to return when there is a zero division case. Defaults to 0. + + Returns: + Dict[str, Dict[str, Union[float, int]]]: Classification report. + """ + # Binarize the multi-label data + y_true_bin, y_pred_bin, classes = simple_multilabel_binarizer(y_true, y_pred) + + # Initialize data structure for the report + report = {} + for i, class_label in enumerate(classes): + support = sum(row[i] for row in y_true_bin) + predicted_labels = sum(row[i] for row in y_pred_bin) + correct_predictions = sum( + 1 + for true_row, pred_row in zip(y_true_bin, y_pred_bin) + if true_row[i] == pred_row[i] == 1 + ) + + # Precision, recall, and F1 score calculations + if predicted_labels > 0: + precision = correct_predictions / predicted_labels + else: + precision = zero_division + + if support > 0: + recall = correct_predictions / support + else: + recall = zero_division + + if (precision + recall) > 0: + f1_score = (2 * precision * recall) / (precision + recall) + else: + f1_score = zero_division + + # Add stats to the report + report[class_label] = { + "precision": precision, + "recall": recall, + "f1-score": f1_score, + "support": support, + } + + # Compute macro averages + avg_precision = sum([metrics["precision"] for metrics in report.values()]) / len( + report + ) + avg_recall = sum([metrics["recall"] for metrics in report.values()]) / len(report) + avg_f1_score = sum([metrics["f1-score"] for metrics in report.values()]) / len(report) + + report["macro avg"] = { + "precision": avg_precision, + "recall": avg_recall, + "f1-score": avg_f1_score, + "support": len(y_true), + } + + return report From 54f235d4b0c58774fbc918562eb833fba71eebfb Mon Sep 17 00:00:00 2001 From: Kalyan Chakravarthy Date: Mon, 16 Sep 2024 23:14:35 +0530 Subject: [PATCH 2/2] Refactor accuracy tests to handle multi-label classification --- langtest/transform/accuracy.py | 43 ++++++++-- langtest/utils/util_metrics.py | 139 ++++++++++++++++++++++++++++++--- 2 files changed, 168 insertions(+), 14 deletions(-) diff --git a/langtest/transform/accuracy.py b/langtest/transform/accuracy.py index 507d42408..c9f4ccfc5 100644 --- a/langtest/transform/accuracy.py +++ b/langtest/transform/accuracy.py @@ -17,6 +17,7 @@ from langtest.errors import Errors from langtest.utils.util_metrics import ( calculate_f1_score, + calculate_f1_score_multi_label, classification_report, classification_report_multi_label, ) @@ -384,7 +385,13 @@ async def run( y_pred (List[Any]): Predicted values """ progress = kwargs.get("progress_bar", False) - df_metrics = classification_report(y_true, y_pred, zero_division=0) + is_multi_label = kwargs.get("is_multi_label", False) + if is_multi_label: + df_metrics = classification_report_multi_label( + y_true, y_pred, zero_division=0 + ) + else: + df_metrics = classification_report(y_true, y_pred, zero_division=0) df_metrics.pop("macro avg") for idx, sample in enumerate(sample_list): @@ -464,7 +471,13 @@ async def run( """ progress = kwargs.get("progress_bar", False) - df_metrics = classification_report(y_true, y_pred, zero_division=0) + is_multi_label = kwargs.get("is_multi_label", False) + if is_multi_label: + df_metrics = classification_report_multi_label( + y_true, y_pred, zero_division=0 + ) + else: + df_metrics = classification_report(y_true, y_pred, zero_division=0) df_metrics.pop("macro avg") for idx, sample in enumerate(sample_list): @@ -614,8 +627,14 @@ async def run( """ progress = kwargs.get("progress_bar", False) + is_multi_label = kwargs.get("is_multi_label", False) - f1 = calculate_f1_score(y_true, y_pred, average="micro", zero_division=0) + if is_multi_label: + f1 = calculate_f1_score_multi_label( + y_true, y_pred, average="micro", zero_division=0 + ) + else: + f1 = calculate_f1_score(y_true, y_pred, average="micro", zero_division=0) for sample in sample_list: sample.actual_results = MinScoreOutput(min_score=f1) @@ -679,7 +698,14 @@ async def run( """ progress = kwargs.get("progress_bar", False) - f1 = calculate_f1_score(y_true, y_pred, average="macro", zero_division=0) + is_multi_label = kwargs.get("is_multi_label", False) + + if is_multi_label: + f1 = calculate_f1_score_multi_label( + y_true, y_pred, average="macro", zero_division=0 + ) + else: + f1 = calculate_f1_score(y_true, y_pred, average="macro", zero_division=0) for sample in sample_list: sample.actual_results = MinScoreOutput(min_score=f1) @@ -741,7 +767,14 @@ async def run( """ progress = kwargs.get("progress_bar", False) - f1 = calculate_f1_score(y_true, y_pred, average="weighted", zero_division=0) + is_multi_label = kwargs.get("is_multi_label", False) + + if is_multi_label: + f1 = calculate_f1_score_multi_label( + y_true, y_pred, average="weighted", zero_division=0 + ) + else: + f1 = calculate_f1_score(y_true, y_pred, average="weighted", zero_division=0) for sample in sample_list: sample.actual_results = MinScoreOutput(min_score=f1) diff --git a/langtest/utils/util_metrics.py b/langtest/utils/util_metrics.py index 76116094e..eb407d34d 100644 --- a/langtest/utils/util_metrics.py +++ b/langtest/utils/util_metrics.py @@ -188,18 +188,15 @@ def simple_multilabel_binarizer(y_true, y_pred): binarized_y_pred (list of lists): Binary matrix of predicted labels. classes (list): List of all unique classes (labels). """ - # Get all unique labels (classes) from both y_true and y_pred + # Ensure we collect unique classes from both y_true and y_pred classes = sorted(set(label for labels in y_true + y_pred for label in labels)) - # Binarize y_true and y_pred based on the classes - binarized_y_true = [ - [1 if label in labels else 0 for label in classes] for labels in y_true - ] - binarized_y_pred = [ - [1 if label in labels else 0 for label in classes] for labels in y_pred - ] + # Create a binary matrix for y_true and y_pred + y_true_bin = [[1 if label in labels else 0 for label in classes] for labels in y_true] + y_pred_bin = [[1 if label in labels else 0 for label in classes] for labels in y_pred] - return binarized_y_true, binarized_y_pred, classes + # Return the binarized labels and the consistent set of classes + return y_true_bin, y_pred_bin, classes def classification_report_multi_label( @@ -271,3 +268,127 @@ def classification_report_multi_label( } return report + + +def calculate_f1_score_multi_label( + y_true: List[Set[Union[str, int]]], + y_pred: List[Set[Union[str, int]]], + average: str = "macro", + zero_division: int = 0, +) -> float: + """ + Calculate the F1 score for multi-label classification using binarized labels. + + Args: + y_true (List[Set[Union[str, int]]]): List of sets of true labels. + y_pred (List[Set[Union[str, int]]]): List of sets of predicted labels. + average (str, optional): Method to calculate F1 score, can be 'micro', 'macro', or 'weighted'. Defaults to 'macro'. + zero_division (int, optional): Value to return when there is a zero division case. Defaults to 0. + + Returns: + float: Calculated F1 score for multi-label classification. + + Raises: + AssertionError: If lengths of y_true and y_pred are not equal. + ValueError: If invalid averaging method is provided. + """ + assert len(y_true) == len(y_pred), "Lengths of y_true and y_pred must be equal." + + # Binarize the labels and get the unique class set + y_true_bin, y_pred_bin, classes = simple_multilabel_binarizer(y_true, y_pred) + + # Number of classes should remain consistent + num_classes = len(classes) + + if average == "micro": + true_positives = sum( + 1 + for i in range(len(y_true_bin)) + for j in range(num_classes) + if y_true_bin[i][j] == y_pred_bin[i][j] == 1 + ) + false_positives = sum( + 1 + for i in range(len(y_true_bin)) + for j in range(num_classes) + if y_pred_bin[i][j] == 1 and y_true_bin[i][j] == 0 + ) + false_negatives = sum( + 1 + for i in range(len(y_true_bin)) + for j in range(num_classes) + if y_pred_bin[i][j] == 0 and y_true_bin[i][j] == 1 + ) + + precision = ( + true_positives / (true_positives + false_positives) + if (true_positives + false_positives) > 0 + else zero_division + ) + recall = ( + true_positives / (true_positives + false_negatives) + if (true_positives + false_negatives) > 0 + else zero_division + ) + f1_score = ( + 2 * (precision * recall) / (precision + recall) + if (precision + recall) > 0 + else zero_division + ) + + elif average in ["macro", "weighted"]: + f1_score = 0.0 + total_support = sum( + sum(y_true_bin[i][j] for i in range(len(y_true_bin))) + for j in range(num_classes) + ) + + for j in range(num_classes): + true_positives = sum( + 1 + for i in range(len(y_true_bin)) + if y_true_bin[i][j] == y_pred_bin[i][j] == 1 + ) + false_positives = sum( + 1 + for i in range(len(y_true_bin)) + if y_pred_bin[i][j] == 1 and y_true_bin[i][j] == 0 + ) + false_negatives = sum( + 1 + for i in range(len(y_true_bin)) + if y_pred_bin[i][j] == 0 and y_true_bin[i][j] == 1 + ) + + precision = ( + true_positives / (true_positives + false_positives) + if (true_positives + false_positives) > 0 + else zero_division + ) + recall = ( + true_positives / (true_positives + false_negatives) + if (true_positives + false_negatives) > 0 + else zero_division + ) + + if precision + recall > 0: + class_f1_score = 2 * (precision * recall) / (precision + recall) + else: + class_f1_score = 0.0 + + # Support for the current class (how many times it appears in y_true) + support = sum(y_true_bin[i][j] for i in range(len(y_true_bin))) + + if average == "macro": + f1_score += class_f1_score / num_classes + elif average == "weighted": + # Normalize weights by dividing the support by the total number of labels + weight = support / total_support if total_support > 0 else 0 + f1_score += weight * class_f1_score + + else: + raise ValueError( + "Invalid averaging method. Must be 'micro', 'macro', or 'weighted'." + ) + + return min(f1_score, 1.0) # Ensure the F1 score is capped at 1.0