Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,7 @@ def _check_label_ids_loaded_from_pkl(
) -> None:
if not isinstance(pkl_punct_label_ids, dict):
raise ValueError(
f"Punctuation label ids loaded from features file {self.features_pkl} has wrong type "
f"Punctuation label ids loaded from features file {self.features_pkl} have wrong type "
f"{type(pkl_punct_label_ids)}"
)
if parameter_punct_label_ids is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import copy
import warnings
from math import ceil
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -770,6 +771,39 @@ def _setup_dataloader_from_config(self, cfg: DictConfig, train: bool) -> torch.u
'punct_label_vocab_file': punct_label_vocab_file,
'capit_label_vocab_file': capit_label_vocab_file,
}
if train:
number_of_batches_is_multiple_of = 1
if self._trainer is None:
warnings.warn(
'The model attribute `trainer` is not set before setting up the training dataset. If training is '
'resumed from a checkpoint, then data loading for the current epoch can be distorted: some batches '
'may be processed several times and some may not be processed at all. `trainer.current_epoch`'
' is used as the random seed for shuffling batches. Now 0 will be used. If the '
'checkpoint was not created during the initial epoch, the shuffling of the dataset will '
'be different. You may try to use the `exp_manager()` function and the '
'`PunctuationCapitalizationModel.set_trainer()` method before the '
'`PunctuationCapitalizationModel.setup_training_data()` method.'
)
batch_shuffling_random_seed = 0
else:
batch_shuffling_random_seed = self._trainer.current_epoch
else:
batch_shuffling_random_seed = 0
if self._trainer is None:
warnings.warn(
'The model attribute `trainer` is not set before setting up the test or validation dataset. If more '
'than 1 GPU is used for testing, then some examples may be tested several times because the '
'number of batches may not be evenly divisible by the number of processes. This leads to '
'distortion of metrics. See more in the description of the `number_of_batches_is_multiple_of` '
'parameter of the `BertPunctuationCapitalizationDataset` class initializer and '
'https://pytorch.org/docs/stable/data.html#multi-process-data-loading. You may try to use the '
'`PunctuationCapitalizationModel.set_trainer()` method before the '
'`PunctuationCapitalizationModel.setup_validation_data()` and '
'`PunctuationCapitalizationModel.setup_test_data()` methods.'
)
number_of_batches_is_multiple_of = 1
else:
number_of_batches_is_multiple_of = self._trainer.num_nodes * self._trainer.num_devices
dataset = BertPunctuationCapitalizationDataset(
tokenizer=self.tokenizer,
text_file=text_file,
Expand All @@ -783,8 +817,8 @@ def _setup_dataloader_from_config(self, cfg: DictConfig, train: bool) -> torch.u
num_samples=cfg.num_samples,
tokens_in_batch=cfg.tokens_in_batch,
n_jobs=cfg.n_jobs,
number_of_batches_is_multiple_of=1 if train else self.trainer.num_nodes * self.trainer.num_devices,
batch_shuffling_random_seed=self.trainer.global_step if train else 42,
number_of_batches_is_multiple_of=number_of_batches_is_multiple_of,
batch_shuffling_random_seed=batch_shuffling_random_seed,
verbose=cfg.verbose,
get_label_frequencies=cfg.get_label_frequences,
cache_dir=cfg.cache_dir,
Expand Down
5 changes: 3 additions & 2 deletions tutorials/nlp/Punctuation_and_Capitalization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -990,14 +990,15 @@
" 'tokens_in_batch': 1024,\n",
" },\n",
")\n",
"pretrained_model.setup_training_data()\n",
"pretrained_model.setup_validation_data()\n",
"\n",
"# and now we can create a PyTorch Lightning trainer and call `fit` again\n",
"# for this tutorial we are setting fast_dev_run to True, and the trainer will run 1 training batch and 1 validation batch\n",
"# for actual model training, disable the flag\n",
"fast_dev_run = True\n",
"trainer = pl.Trainer(devices=1, accelerator='gpu', fast_dev_run=fast_dev_run)\n",
"pretrained_model.set_trainer(trainer)\n",
"pretrained_model.setup_training_data()\n",
"pretrained_model.setup_validation_data()\n",
"trainer.fit(pretrained_model)"
]
}
Expand Down