diff --git a/nemo/collections/nlp/data/token_classification/punctuation_capitalization_dataset.py b/nemo/collections/nlp/data/token_classification/punctuation_capitalization_dataset.py
index 48c4c683222e..929f535cfc8d 100644
--- a/nemo/collections/nlp/data/token_classification/punctuation_capitalization_dataset.py
+++ b/nemo/collections/nlp/data/token_classification/punctuation_capitalization_dataset.py
@@ -1110,7 +1110,7 @@ def _check_label_ids_loaded_from_pkl(
     ) -> None:
         if not isinstance(pkl_punct_label_ids, dict):
             raise ValueError(
-                f"Punctuation label ids loaded from features file {self.features_pkl} has wrong type "
+                f"Punctuation label ids loaded from features file {self.features_pkl} have wrong type "
                 f"{type(pkl_punct_label_ids)}"
             )
         if parameter_punct_label_ids is not None:
diff --git a/nemo/collections/nlp/models/token_classification/punctuation_capitalization_model.py b/nemo/collections/nlp/models/token_classification/punctuation_capitalization_model.py
index 5a5f6c025eea..5f6fa7f6164f 100644
--- a/nemo/collections/nlp/models/token_classification/punctuation_capitalization_model.py
+++ b/nemo/collections/nlp/models/token_classification/punctuation_capitalization_model.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import copy
+import warnings
 from math import ceil
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -770,6 +771,39 @@ def _setup_dataloader_from_config(self, cfg: DictConfig, train: bool) -> torch.u
             'punct_label_vocab_file': punct_label_vocab_file,
             'capit_label_vocab_file': capit_label_vocab_file,
         }
+        if train:
+            number_of_batches_is_multiple_of = 1
+            if self._trainer is None:
+                warnings.warn(
+                    'The model attribute `trainer` was not set before the training dataset was set up. If '
+                    'training is resumed from a checkpoint, data loading for the current epoch may be distorted: '
+                    'some batches may be processed several times and others not at all. `trainer.current_epoch` '
+                    'is used as the random seed for batch shuffling; since no trainer is available, 0 will be '
+                    'used instead, so if the checkpoint was not created during the initial epoch, the dataset '
+                    'will be shuffled differently. Consider calling the `exp_manager()` function and the '
+                    '`PunctuationCapitalizationModel.set_trainer()` method before '
+                    '`PunctuationCapitalizationModel.setup_training_data()`.'
+                )
+                batch_shuffling_random_seed = 0
+            else:
+                batch_shuffling_random_seed = self._trainer.current_epoch
+        else:
+            batch_shuffling_random_seed = 0
+            if self._trainer is None:
+                warnings.warn(
+                    'The model attribute `trainer` was not set before the test or validation dataset was set '
+                    'up. If more than one GPU is used for testing, some examples may be evaluated several times '
+                    'because the number of batches may not be evenly divisible by the number of processes, '
+                    'which distorts the metrics. For details, see the description of the '
+                    '`number_of_batches_is_multiple_of` parameter of the `BertPunctuationCapitalizationDataset` '
+                    'initializer and https://pytorch.org/docs/stable/data.html#multi-process-data-loading. '
+                    'Consider calling the `PunctuationCapitalizationModel.set_trainer()` method before '
+                    '`PunctuationCapitalizationModel.setup_validation_data()` and '
+                    '`PunctuationCapitalizationModel.setup_test_data()`.'
+                )
+                number_of_batches_is_multiple_of = 1
+            else:
+                number_of_batches_is_multiple_of = self._trainer.num_nodes * self._trainer.num_devices
         dataset = BertPunctuationCapitalizationDataset(
             tokenizer=self.tokenizer,
             text_file=text_file,
@@ -783,8 +817,8 @@ def _setup_dataloader_from_config(self, cfg: DictConfig, train: bool) -> torch.u
             num_samples=cfg.num_samples,
             tokens_in_batch=cfg.tokens_in_batch,
             n_jobs=cfg.n_jobs,
-            number_of_batches_is_multiple_of=1 if train else self.trainer.num_nodes * self.trainer.num_devices,
-            batch_shuffling_random_seed=self.trainer.global_step if train else 42,
+            number_of_batches_is_multiple_of=number_of_batches_is_multiple_of,
+            batch_shuffling_random_seed=batch_shuffling_random_seed,
             verbose=cfg.verbose,
             get_label_frequencies=cfg.get_label_frequences,
             cache_dir=cfg.cache_dir,
diff --git a/tutorials/nlp/Punctuation_and_Capitalization.ipynb b/tutorials/nlp/Punctuation_and_Capitalization.ipynb
index 8b4b91eff699..7765222bb03e 100644
--- a/tutorials/nlp/Punctuation_and_Capitalization.ipynb
+++ b/tutorials/nlp/Punctuation_and_Capitalization.ipynb
@@ -990,14 +990,15 @@
     "        'tokens_in_batch': 1024,\n",
     "    },\n",
     ")\n",
-    "pretrained_model.setup_training_data()\n",
-    "pretrained_model.setup_validation_data()\n",
     "\n",
     "# and now we can create a PyTorch Lightning trainer and call `fit` again\n",
     "# for this tutorial we are setting fast_dev_run to True, and the trainer will run 1 training batch and 1 validation batch\n",
     "# for actual model training, disable the flag\n",
     "fast_dev_run = True\n",
     "trainer = pl.Trainer(devices=1, accelerator='gpu', fast_dev_run=fast_dev_run)\n",
+    "pretrained_model.set_trainer(trainer)\n",
+    "pretrained_model.setup_training_data()\n",
+    "pretrained_model.setup_validation_data()\n",
     "trainer.fit(pretrained_model)"
    ]
   }
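For context, the notebook fix above comes down to attaching the trainer before the data-setup calls, so that `_setup_dataloader_from_config()` can read the trainer's state instead of falling back to defaults with a warning. A minimal sketch of the recommended order, assuming a working NeMo installation (the `punctuation_en_bert` model name is illustrative and not taken from this diff):

```python
import pytorch_lightning as pl
from nemo.collections.nlp.models import PunctuationCapitalizationModel

# Illustrative pretrained model; any punctuation/capitalization checkpoint works.
pretrained_model = PunctuationCapitalizationModel.from_pretrained("punctuation_en_bert")

trainer = pl.Trainer(devices=1, accelerator="gpu", fast_dev_run=True)

# Attach the trainer BEFORE setting up the data: the dataloader setup reads
# `trainer.current_epoch` (batch shuffling seed for training) and
# `trainer.num_nodes * trainer.num_devices` (number_of_batches_is_multiple_of
# for evaluation) from the attached trainer.
pretrained_model.set_trainer(trainer)
pretrained_model.setup_training_data()
pretrained_model.setup_validation_data()

trainer.fit(pretrained_model)
```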