diff --git a/langtest/langtest.py b/langtest/langtest.py index e30ab002f..f3e1f51d6 100644 --- a/langtest/langtest.py +++ b/langtest/langtest.py @@ -1758,6 +1758,7 @@ def get_leaderboard( category=False, split_wise=False, test_wise=False, + rank_by: Union[str, list] = "Avg", *args, **kwargs, ): @@ -1776,15 +1777,15 @@ def get_leaderboard( if indices or columns: return leaderboard.custom_wise(indices, columns) if category: - return leaderboard.category_wise() + return leaderboard.category_wise(rank_by=rank_by) if test_wise: - return leaderboard.test_wise() + return leaderboard.test_wise(rank_by=rank_by) if split_wise: - return leaderboard.split_wise() + return leaderboard.split_wise(rank_by=rank_by) - return leaderboard.default() + return leaderboard.default(rank_by=rank_by) def __temp_generate(self, *args, **kwargs): """Temporary function to generate the testcases.""" diff --git a/langtest/utils/benchmark_utils.py b/langtest/utils/benchmark_utils.py index 4b07dae65..f3e08f839 100644 --- a/langtest/utils/benchmark_utils.py +++ b/langtest/utils/benchmark_utils.py @@ -1,5 +1,5 @@ import os -from typing import TypeVar, Generic +from typing import TypeVar, Generic, Union import pandas as pd @@ -35,10 +35,16 @@ def __init__( """ self.summary = Summary(path, *args, **kwargs) - def default(self): + def default(self, rank_by: Union[str, list] = "Avg"): """ Get the score board for the models """ + # check if the rank_by is a string + if isinstance(rank_by, str): + rank_by = [rank_by] + + ascending = [False] * len(rank_by) + df = self.summary.summary_df df = self.__drop_duplicates(df) pvt_table = df.pivot_table( @@ -47,7 +53,7 @@ def default(self): # mean column pvt_table.insert(0, "Avg", pvt_table.mean(axis=1)) - pvt_table = pvt_table.sort_values(by=["model", "Avg"], ascending=[True, False]) + pvt_table = pvt_table.sort_values(by=rank_by, ascending=ascending) # reset the index and fill the NaN values pvt_table = pvt_table.rename_axis(None, axis=1).reset_index() @@ -55,11 +61,17 @@ def default(self): return pvt_table - def split_wise(self): + def split_wise(self, rank_by: Union[str, list] = "Avg"): """ Get the score board for the models by test type """ + # check if the rank_by is a string + if isinstance(rank_by, str): + rank_by = [rank_by] + + ascending = [False] * len(rank_by) + df = self.summary.summary_df df = self.__drop_duplicates(df) pvt_table = df.pivot_table( @@ -70,51 +82,77 @@ def split_wise(self): # mean column pvt_table.insert(0, "Avg", pvt_table.mean(axis=1)) - pvt_table = pvt_table.sort_values(by=["model", "Avg"], ascending=[True, False]) + pvt_table = pvt_table.sort_values(by=rank_by, ascending=ascending) pvt_table = pvt_table.fillna("-") return pvt_table - def test_wise(self): + def test_wise(self, rank_by: Union[str, list] = "Avg"): """ Get the score board for the models by test type """ + # check if the rank_by is a string + if isinstance(rank_by, str): + rank_by = [rank_by] + + # check if the test_type in the rank_by + if "test_type" not in rank_by: + rank_by.insert(0, "test_type") + rank_by.insert(0, "category") + + ascending = [True, True] + [False] * (len(rank_by) - 2) + df = self.summary.summary_df df = self.__drop_duplicates(df) pvt_table = df.pivot_table( - index=["model", "test_type"], columns=["dataset_name"], values="score" + index=["category", "test_type", "model"], + columns=["dataset_name"], + values="score", ) # mean column pvt_table.insert(0, "Avg", pvt_table.mean(axis=1)) - pvt_table = pvt_table.sort_values(by=["model", "Avg"], ascending=[True, False]) + pvt_table = pvt_table.sort_values(by=rank_by, ascending=ascending) pvt_table = pvt_table.fillna("-") return pvt_table - def category_wise(self): + def category_wise(self, rank_by: Union[str, list] = "Avg"): """ Get the score board for the models by category """ + # check if the rank_by is a string + if isinstance(rank_by, str): + rank_by = [rank_by] + + ascending = [False] * len(rank_by) + df = self.summary.summary_df df = self.__drop_duplicates(df) pvt_table = df.pivot_table( - index=["model", "category"], columns=["dataset_name"], values="score" + index=["category", "model"], columns=["dataset_name"], values="score" ) pvt_table.insert(0, "Avg", pvt_table.mean(axis=1)) - pvt_table = pvt_table.sort_values(by=["model", "Avg"], ascending=[True, False]) + pvt_table = pvt_table.sort_values(by=rank_by, ascending=ascending) pvt_table = pvt_table.fillna("-") - pvt_table = pvt_table.rename_axis(None, axis=1).reset_index() return pvt_table - def custom_wise(self, indices: list, columns: list = []): + def custom_wise( + self, indices: list, columns: list = [], rank_by: Union[str, list] = "Avg" + ): """ Get the score board for the models by custom group """ + # check if the rank_by is a string + if isinstance(rank_by, str): + rank_by = [rank_by] + + ascending = [False] * len(rank_by) + df = self.summary.summary_df df = self.__drop_duplicates(df) pvt_table = df.pivot_table( @@ -125,8 +163,7 @@ def custom_wise(self, indices: list, columns: list = []): ) pvt_table.insert(0, "Avg", pvt_table.mean(axis=1)) pvt_table = pvt_table.fillna("-") - pvt_table = pvt_table.sort_values(by=["model", "Avg"], ascending=[True, False]) - # pvt_table = pvt_table.rename_axis(None, axis=1).reset_index() + pvt_table = pvt_table.sort_values(by=rank_by, ascending=ascending) return pvt_table