diff --git a/langtest/modelhandler/transformers_modelhandler.py b/langtest/modelhandler/transformers_modelhandler.py
index 538d6fdb6..5055074de 100644
--- a/langtest/modelhandler/transformers_modelhandler.py
+++ b/langtest/modelhandler/transformers_modelhandler.py
@@ -10,8 +10,8 @@
     SequenceClassificationOutput,
     TranslationOutput,
 )
-
-from langchain import PromptTemplate
+from langtest.utils.lib_manager import try_import_lib
+import importlib
 
 
 class PretrainedModelForNER(_ModelHandler):
@@ -357,6 +357,16 @@ def __init__(self, hub, model, **kwargs):
         )
         self.model = model
 
+    def _check_langchain_package(self):
+        LIB_NAME = "langchain"
+        if try_import_lib(LIB_NAME):
+            langchain = importlib.import_module(LIB_NAME)
+            self.PromptTemplate = getattr(langchain, "PromptTemplate")
+        else:
+            raise ModuleNotFoundError(
+                f"The '{LIB_NAME}' package is not installed. Please install it using 'pip install {LIB_NAME}'."
+            )
+
     @staticmethod
     def load_model(hub: str, path: str, **kwargs) -> "Pipeline":
         """Load the QA model into the `model` attribute.
@@ -383,7 +393,7 @@ def predict(self, text: Union[str, dict], prompt: dict, **kwargs) -> str:
         Returns:
             str: Output model for QA tasks
         """
-        prompt_template = PromptTemplate(**prompt)
+        prompt_template = self.PromptTemplate(**prompt)
         p = prompt_template.format(**text)
         prediction = self.model(p, **kwargs)
         return prediction[0]["generated_text"][len(p) :]