diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py index ffae75ed5a34..67448badb43a 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py @@ -436,6 +436,12 @@ def save_checkpoint_as_nemo_file(self): self.virtual_prompt_style = current_virtual_prompt_style self.virtual_prompt_source = current_virtual_prompt_source + # Revert prompt table back to previous state + if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING and self.first_stage_of_pipeline(): + for taskname in current_new_tasks: + if taskname in self.prompt_table.prompt_table: + del self.prompt_table.prompt_table[taskname] + with open_dict(self.cfg): self.cfg.existing_tasks = current_existing_tasks self.cfg.new_tasks = current_new_tasks diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py index add7c898c80c..5b083ed86b93 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py @@ -707,6 +707,12 @@ def save_checkpoint_as_nemo_file(self): self.virtual_prompt_style = current_virtual_prompt_style self.virtual_prompt_source = current_virtual_prompt_source + # Revert prompt table back to previous state + if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING and self.frozen_model.model.pre_process: + for taskname in current_new_tasks: + if taskname in self.prompt_table.prompt_table: + del self.prompt_table.prompt_table[taskname] + with open_dict(self.cfg): self.cfg.existing_tasks = current_existing_tasks self.cfg.new_tasks = current_new_tasks