Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion langtest/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
IPython = importlib.import_module(LIB_NAME)
self.display = getattr(IPython.display, "display")
else:
raise ModuleNotFoundError(Errors.E023.format(LIB_NAME=LIB_NAME))
raise ModuleNotFoundError(Errors.E023(LIB_NAME=LIB_NAME))

def on_init_end(self, args, state, control, **kwargs):
model = kwargs["model"]
Expand Down
34 changes: 17 additions & 17 deletions langtest/datahandler/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def filter_curated_bias(
if item.test_type in tests_to_filter:
data.append(item)
logging.warning(
Warnings.W003.format(
Warnings.W003(
len_bias_data=len(bias_data),
len_samples_removed=len(bias_data) - len(data),
)
Expand Down Expand Up @@ -356,14 +356,14 @@ def _load_dataset(cls, custom_label: dict) -> str:
if "split" not in dataset_info:
if subset is None:
subset = list(dataset_info.keys())[0]
logging.warning(Warnings.W012.format(var1="subset", var2=subset))
logging.warning(Warnings.W012(var1="subset", var2=subset))
if split is None:
split = dataset_info[subset]["split"][0]
logging.warning(Warnings.W012.format(var1="split", var2=split))
logging.warning(Warnings.W012(var1="split", var2=split))

if subset not in dataset_info or split not in dataset_info[subset]["split"]:
raise ValueError(
Errors.E082.format(
Errors.E082(
subset=subset,
split=split,
dataset_name=dataset_name,
Expand All @@ -386,11 +386,11 @@ def _load_dataset(cls, custom_label: dict) -> str:
else:
if split is None:
split = dataset_info["split"][0]
logging.warning(Warnings.W012.format(var1="split", var2=split))
logging.warning(Warnings.W012(var1="split", var2=split))

if split not in dataset_info["split"]:
raise ValueError(
Errors.E083.format(
Errors.E083(
split=split,
dataset_name=dataset_name,
available_splits=", ".join(dataset_info["split"]),
Expand Down Expand Up @@ -449,7 +449,7 @@ def load_raw_data(self) -> List[Dict]:
valid_tokens, token_list = self.__token_validation(tokens)

if not valid_tokens:
logging.warning(Warnings.W004.format(sent=sent))
logging.warning(Warnings.W004(sent=sent))
continue

# get token and labels from the split
Expand Down Expand Up @@ -491,7 +491,7 @@ def load_data(self) -> List[NERSample]:
valid_tokens, token_list = self.__token_validation(tokens)

if not valid_tokens:
logging.warning(Warnings.W004.format(sent=sent))
logging.warning(Warnings.W004(sent=sent))
continue

# get token and labels from the split
Expand Down Expand Up @@ -565,7 +565,7 @@ def __token_validation(self, tokens: str) -> (bool, List[List[str]]): # type: i
token_list.append(tsplit)
valid_labels.append(tsplit[-1])
else:
logging.warning(Warnings.W008.format(sent=t))
logging.warning(Warnings.W008(sent=t))
return False, token_list

if valid_labels[0].startswith("I-"):
Expand Down Expand Up @@ -671,7 +671,7 @@ def __init__(self, file_path: Union[str, Dict], task: TaskManager, **kwargs) ->
if task_name in self.COLUMN_NAMES:
self.COLUMN_NAMES = self.COLUMN_NAMES[task_name]
elif "is_import" not in kwargs:
raise ValueError(Errors.E026.format(task=task))
raise ValueError(Errors.E026(task=task))

self.column_map = None
self.kwargs = kwargs
Expand Down Expand Up @@ -699,7 +699,7 @@ def load_raw_data(self, standardize_columns: bool = False) -> List[Dict]:

if feature_column not in df.columns or target_column not in df.columns:
raise ValueError(
Errors.E027.format(
Errors.E027(
feature_column=feature_column, target_column=target_column
)
)
Expand Down Expand Up @@ -810,7 +810,7 @@ def load_data(self) -> List[Sample]:
data.append(sample)

except Exception as e:
logging.warning(Warnings.W005.format(idx=idx, row_data=row_data, e=e))
logging.warning(Warnings.W005(idx=idx, row_data=row_data, e=e))
continue

return data
Expand Down Expand Up @@ -993,7 +993,7 @@ def _match_column_names(self, column_names: List[str]) -> Dict[str, str]:

if "text" in not_referenced_columns:
raise OSError(
Errors.E029.format(
Errors.E029(
valid_column_names=self.COLUMN_NAMES[self.task.task_name]["text"],
column_names=column_names,
)
Expand Down Expand Up @@ -1157,7 +1157,7 @@ def _check_datasets_package(self):
dataset_module = importlib.import_module(self.LIB_NAME)
self.load_dataset = getattr(dataset_module, "load_dataset")
else:
raise ModuleNotFoundError(Errors.E023.format(LIB_NAME=self.LIB_NAME))
raise ModuleNotFoundError(Errors.E023(LIB_NAME=self.LIB_NAME))

def load_raw_data(
self,
Expand Down Expand Up @@ -1299,7 +1299,7 @@ def load_data(self) -> List[Sample]:
samples = getattr(self, method_name)()
return samples
else:
raise ValueError(Errors.E030.format(dataset_name=self.dataset_name))
raise ValueError(Errors.E030(dataset_name=self.dataset_name))

@staticmethod
def extract_data_with_equal_proportion(data_dict, total_samples):
Expand Down Expand Up @@ -1613,7 +1613,7 @@ def __init__(self, file_path: str, task: TaskManager, **kwargs) -> None:
if task.task_name in self.COLUMN_NAMES:
self.COLUMN_NAMES = self.COLUMN_NAMES[task.task_name]
elif "is_import" not in kwargs:
raise ValueError(Errors.E026.format(task=task))
raise ValueError(Errors.E026(task=task))

self.column_map = None
self.kwargs = kwargs
Expand Down Expand Up @@ -1690,7 +1690,7 @@ def load_data(self) -> List[Sample]:
data.append(sample)

except Exception as e:
logging.warning(Warnings.W005.format(idx=idx, row_data=row_data, e=e))
logging.warning(Warnings.W005(idx=idx, row_data=row_data, e=e))
continue

return data
Expand Down
2 changes: 1 addition & 1 deletion langtest/datahandler/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def process(sample: Sample, output_format: str, *args, **kwargs):
sample, *args, **kwargs
)
except KeyError:
raise NameError(Errors.E031.format(class_name=class_name))
raise NameError(Errors.E031(class_name=class_name))


class SequenceClassificationOutputFormatter(BaseFormatter, ABC):
Expand Down
4 changes: 2 additions & 2 deletions langtest/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, model="text-embedding-ada-002"):
self.openai = None
self._check_openai_package()
if not self.api_key:
raise ValueError(Errors.E032)
raise ValueError(Errors.E032())

self.openai.api_key = self.api_key

Expand All @@ -28,7 +28,7 @@ def _check_openai_package(self):
if try_import_lib(self.LIB_NAME):
self.openai = importlib.import_module(self.LIB_NAME)
else:
raise ModuleNotFoundError(Errors.E023.format(LIB_NAME=self.LIB_NAME))
raise ModuleNotFoundError(Errors.E023(LIB_NAME=self.LIB_NAME))

@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
Expand Down
28 changes: 22 additions & 6 deletions langtest/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class ErrorsWithCodes(type):
This metaclass is used to create error and warning classes, such as Errors and Warnings.
"""

def __getattribute__(self, code):
def __getattribute__(self, code: str):
"""
Retrieve the error/warning message associated with a given code.

Expand All @@ -28,11 +28,27 @@ def __getattribute__(self, code):
Example:
error_message = Errors.E000
"""
from langtest.logger import logger

msg = super().__getattribute__(code)
if code.startswith("__"):
if code.startswith("__") or code.startswith("_") or code.startswith("get"):
return msg
else:
return "[{code}] {msg}".format(code=code, msg=msg)
def formatted_msg(**kwargs):
formatted_message = msg.format(**kwargs)
out = f"[{code}] {formatted_message}"
if code.startswith("E"):
logger.exception(out, exc_info=False)
elif code.startswith("W"):
logger.warning(out)
elif code.startswith("I"):
logger.info(out)
elif code.startswith("D"):
logger.debug(out)
elif code.startswith("C"):
logger.critical(out)
return formatted_message
return formatted_msg


class Warnings(metaclass=ErrorsWithCodes):
Expand Down Expand Up @@ -66,8 +82,8 @@ class Warnings(metaclass=ErrorsWithCodes):
W006 = ("target_column '{target_column}' not found in the dataset.")
W007 = ("'feature_column' '{feature_column}' not found in the dataset.")
W008 = ("Invalid or Missing label entries in the sentence: {sent}")
W009 = ("Removing samples where no transformation has been applied:\n")
W010 = ("- Test '{test}': {count} samples removed out of {total_sample}\n")
_W009 = ("Removing samples where no transformation has been applied:\n")
_W010 = ("- Test '{test}': {count} samples removed out of {total_sample}\n")
W011 = ("{class_name} successfully ran!")
W012 = ("You haven't provided the {var1}. Loading the default {var1}: {var2}")
W013 = ("Unable to find test_cases.pkl inside {save_dir}. Generating new testcases.")
Expand Down Expand Up @@ -270,7 +286,7 @@ def __init__(
supported_columns,
given_columns,
):
self.message = Errors.E077.format(
self.message = Errors.E077(
supported_columns=supported_columns, given_columns=given_columns
)
super().__init__(self.message)
Loading