From 4b261de11c22e8836c0c213934bbd2500abcd661 Mon Sep 17 00:00:00 2001 From: Prikshit7766 Date: Fri, 1 Sep 2023 16:21:16 +0530 Subject: [PATCH 1/3] fix import PromptTemplate --- .../modelhandler/transformers_modelhandler.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/langtest/modelhandler/transformers_modelhandler.py b/langtest/modelhandler/transformers_modelhandler.py index 538d6fdb6..04e8d652b 100644 --- a/langtest/modelhandler/transformers_modelhandler.py +++ b/langtest/modelhandler/transformers_modelhandler.py @@ -10,8 +10,8 @@ SequenceClassificationOutput, TranslationOutput, ) - -from langchain import PromptTemplate +from langtest.utils.lib_manager import try_import_lib +import importlib class PretrainedModelForNER(_ModelHandler): @@ -345,6 +345,15 @@ class PretrainedModelForQA(_ModelHandler): model (transformers.pipeline.Pipeline): Pretrained HuggingFace QA pipeline for predictions. """ + LIB_NAME = "langchain" + if try_import_lib(LIB_NAME): + langchain = importlib.import_module(LIB_NAME) + PromptTemplate = getattr(langchain, "PromptTemplate") + else: + raise ModuleNotFoundError( + f"The '{LIB_NAME}' package is not installed. Please install it using 'pip install {LIB_NAME}'." + ) + def __init__(self, hub, model, **kwargs): """Constructor method @@ -383,7 +392,7 @@ def predict(self, text: Union[str, dict], prompt: dict, **kwargs) -> str: Returns: str: Output model for QA tasks """ - prompt_template = PromptTemplate(**prompt) + prompt_template = self.PromptTemplate(**prompt) p = prompt_template.format(**text) prediction = self.model(p, **kwargs) return prediction[0]["generated_text"][len(p) :] From 01ab6fd1687ad289ef5af6aa60a7fdf7b501ebfc Mon Sep 17 00:00:00 2001 From: Prikshit7766 Date: Fri, 1 Sep 2023 16:43:15 +0530 Subject: [PATCH 2/3] add _check_langchain_package --- .../modelhandler/transformers_modelhandler.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/langtest/modelhandler/transformers_modelhandler.py b/langtest/modelhandler/transformers_modelhandler.py index 04e8d652b..d399e3b87 100644 --- a/langtest/modelhandler/transformers_modelhandler.py +++ b/langtest/modelhandler/transformers_modelhandler.py @@ -345,15 +345,6 @@ class PretrainedModelForQA(_ModelHandler): model (transformers.pipeline.Pipeline): Pretrained HuggingFace QA pipeline for predictions. """ - LIB_NAME = "langchain" - if try_import_lib(LIB_NAME): - langchain = importlib.import_module(LIB_NAME) - PromptTemplate = getattr(langchain, "PromptTemplate") - else: - raise ModuleNotFoundError( - f"The '{LIB_NAME}' package is not installed. Please install it using 'pip install {LIB_NAME}'." - ) - def __init__(self, hub, model, **kwargs): """Constructor method @@ -366,6 +357,16 @@ def __init__(self, hub, model, **kwargs): ) self.model = model + def _check_langchain_package(self): + LIB_NAME = "langchain" + if try_import_lib(LIB_NAME): + langchain = importlib.import_module(LIB_NAME) + PromptTemplate = getattr(langchain, "PromptTemplate") + else: + raise ModuleNotFoundError( + f"The '{LIB_NAME}' package is not installed. Please install it using 'pip install {LIB_NAME}'." + ) + @staticmethod def load_model(hub: str, path: str, **kwargs) -> "Pipeline": """Load the QA model into the `model` attribute. From a955472bf144d2a340995f8b6ee3904056541e25 Mon Sep 17 00:00:00 2001 From: Prikshit7766 Date: Fri, 1 Sep 2023 16:58:51 +0530 Subject: [PATCH 3/3] minor fix --- langtest/modelhandler/transformers_modelhandler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langtest/modelhandler/transformers_modelhandler.py b/langtest/modelhandler/transformers_modelhandler.py index d399e3b87..5055074de 100644 --- a/langtest/modelhandler/transformers_modelhandler.py +++ b/langtest/modelhandler/transformers_modelhandler.py @@ -361,7 +361,7 @@ def _check_langchain_package(self): LIB_NAME = "langchain" if try_import_lib(LIB_NAME): langchain = importlib.import_module(LIB_NAME) - PromptTemplate = getattr(langchain, "PromptTemplate") + self.PromptTemplate = getattr(langchain, "PromptTemplate") else: raise ModuleNotFoundError( f"The '{LIB_NAME}' package is not installed. Please install it using 'pip install {LIB_NAME}'."