From b6e0d133b2f653a8dd232a8d5082b210e9bff5f4 Mon Sep 17 00:00:00 2001
From: Seonghun Noh
Date: Sat, 14 Oct 2023 01:40:51 +0900
Subject: [PATCH] fix ptl_bugs in slu_models.py (#7689)

* fix ptl_bugs in slu_models.py

Signed-off-by: Seonghun Noh

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change strategy to ddp_find_unused_parameters_true in slu example yaml

Signed-off-by: Seonghun

---------

Signed-off-by: Seonghun Noh
Signed-off-by: Seonghun
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com>
---
 .../conformer_transformer_large_bpe.yaml      |  2 +-
 nemo/collections/asr/models/slu_models.py     | 30 ++++++++++++-------
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/examples/slu/speech_intent_slot/configs/conformer_transformer_large_bpe.yaml b/examples/slu/speech_intent_slot/configs/conformer_transformer_large_bpe.yaml
index 10d59f2b5c0d..5d309f3cd193 100644
--- a/examples/slu/speech_intent_slot/configs/conformer_transformer_large_bpe.yaml
+++ b/examples/slu/speech_intent_slot/configs/conformer_transformer_large_bpe.yaml
@@ -174,7 +174,7 @@ trainer:
   max_steps: -1 # computed at runtime if not set
   val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
   accelerator: auto
-  strategy: ddp
+  strategy: ddp_find_unused_parameters_true
   accumulate_grad_batches: 1
   gradient_clip_val: 0.0
   precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.

diff --git a/nemo/collections/asr/models/slu_models.py b/nemo/collections/asr/models/slu_models.py
index 6df907334662..59323cfbfffe 100644
--- a/nemo/collections/asr/models/slu_models.py
+++ b/nemo/collections/asr/models/slu_models.py
@@ -285,7 +285,7 @@ def predict(
         predictions = self.sequence_generator.decode_semantics_from_tokens(pred_tokens)
         return predictions

-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
+    def validation_pass(self, batch, batch_idx, dataloader_idx=0):
         if len(batch) == 4:
             signal, signal_len, semantics, semantics_len = batch
         else:
@@ -327,19 +327,29 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
             'val_wer': wer,
         }

+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
+        metrics = self.validation_pass(batch, batch_idx, dataloader_idx)
+        if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
+            self.validation_step_outputs[dataloader_idx].append(metrics)
+        else:
+            self.validation_step_outputs.append(metrics)
+        return metrics
+
     def test_step(self, batch, batch_idx, dataloader_idx=0):
-        logs = self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx)
-        test_logs = {
-            'test_loss': logs['val_loss'],
-            'test_wer_num': logs['val_wer_num'],
-            'test_wer_denom': logs['val_wer_denom'],
-            'test_wer': logs['val_wer'],
-        }
+        logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx)
+        test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()}
+        if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1:
+            self.test_step_outputs[dataloader_idx].append(test_logs)
+        else:
+            self.test_step_outputs.append(test_logs)
         return test_logs

     def test_dataloader(self):
-        if self._test_dl is not None:
-            return self._test_dl
+        if self._test_dl is None:
+            # None dataloader no longer supported in PTL2.0
+            self._test_dl = []
+
+        return self._test_dl

     def _setup_dataloader_from_config(self, config: Optional[Dict]):
         if 'augmentor' in config:
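Background note (not part of the patch): the diff follows the PyTorch Lightning 2.x convention in which validation_epoch_end(outputs) no longer exists and step outputs are not handed to an epoch-end hook, so the module collects them itself and aggregates them in on_validation_epoch_end. Below is a minimal, self-contained sketch of that pattern under those assumptions; the class, layer, and metric names are illustrative and are not NeMo code.

# pl2_step_outputs_sketch.py -- illustrative only, not part of the NeMo patch
import torch
import pytorch_lightning as pl


class SketchModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 1)
        # Lightning 2.x no longer passes step outputs to an epoch-end hook, so the
        # module keeps its own list. The patch above applies the same idea with
        # validation_step_outputs / test_step_outputs, which become lists of lists
        # (indexed by dataloader_idx) when more than one dataloader is configured.
        self.validation_step_outputs = []

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        metrics = {"val_loss": loss}
        self.validation_step_outputs.append(metrics)  # collect per-step metrics
        return metrics

    def on_validation_epoch_end(self):
        # Aggregate what validation_step collected, then clear it for the next epoch.
        avg_loss = torch.stack([m["val_loss"] for m in self.validation_step_outputs]).mean()
        self.log("val_loss_epoch", avg_loss, sync_dist=True)
        self.validation_step_outputs.clear()

The yaml change is related: ddp_find_unused_parameters_true is the Lightning 2.x strategy alias for DDP with find_unused_parameters=True, which avoids gradient-synchronization errors when some parameters receive no gradient in a given step.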