From 2f5c424e68cd42e09f1276cfe594307628179520 Mon Sep 17 00:00:00 2001 From: Virginia Adams <78445382+vadam5@users.noreply.github.com> Date: Mon, 14 Nov 2022 11:10:13 -0800 Subject: [PATCH] Fix for prompt table restore error (#5393) * Fix for prompt table restore error Signed-off-by: Virginia Adams * Added more safety checks Signed-off-by: Virginia Adams * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added more condition checks Signed-off-by: Virginia Adams * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Virginia Adams Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../megatron_base_prompt_learning_model.py | 6 ++++++ .../language_modeling/megatron_gpt_prompt_learning_model.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py index ffae75ed5a34..67448badb43a 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py @@ -436,6 +436,12 @@ def save_checkpoint_as_nemo_file(self): self.virtual_prompt_style = current_virtual_prompt_style self.virtual_prompt_source = current_virtual_prompt_source + # Revert prompt table back to previous state + if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING and self.first_stage_of_pipeline(): + for taskname in current_new_tasks: + if taskname in self.prompt_table.prompt_table: + del self.prompt_table.prompt_table[taskname] + with open_dict(self.cfg): self.cfg.existing_tasks = current_existing_tasks self.cfg.new_tasks = current_new_tasks diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py 
b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py index add7c898c80c..5b083ed86b93 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py @@ -707,6 +707,12 @@ def save_checkpoint_as_nemo_file(self): self.virtual_prompt_style = current_virtual_prompt_style self.virtual_prompt_source = current_virtual_prompt_source + # Revert prompt table back to previous state + if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING and self.frozen_model.model.pre_process: + for taskname in current_new_tasks: + if taskname in self.prompt_table.prompt_table: + del self.prompt_table.prompt_table[taskname] + with open_dict(self.cfg): self.cfg.existing_tasks = current_existing_tasks self.cfg.new_tasks = current_new_tasks