diff --git a/Jenkinsfile b/Jenkinsfile
index 5d81a57c04c9..4f9220da1fc6 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -3485,7 +3485,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
         trainer.accelerator=gpu \
         trainer.log_every_n_steps=1 \
         trainer.val_check_interval=2 \
-        trainer.limit_val_batches=2 \
+        trainer.limit_val_batches=1.0 \
         trainer.accumulate_grad_batches=1 \
         trainer.max_steps=3 \
         trainer.precision=16 \
@@ -3520,7 +3520,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
         trainer.accelerator=gpu \
         trainer.log_every_n_steps=1 \
         trainer.val_check_interval=2 \
-        trainer.limit_val_batches=2 \
+        trainer.limit_val_batches=1.0 \
         trainer.accumulate_grad_batches=1 \
         trainer.max_steps=6 \
         trainer.precision=16 \
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py b/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
index f977846477b0..6818f99d0e4f 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
@@ -81,9 +81,12 @@ def __len__(self):
         num_available_samples: int = self.total_samples - self.consumed_samples
         if self.global_batch_size is not None:
             if self.drop_last:
-                return num_available_samples // self.global_batch_size
+                num_global_batches = num_available_samples // self.global_batch_size
             else:
-                return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
+                num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
+            # return len of dataloader in terms of micro batches to avoid discrepancy between len of dataloader and
+            # num of batches fetched (as training step fetches in terms of micro batches)
+            return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size)
         else:
             return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1
 
@@ -162,9 +165,12 @@ def __len__(self):
         num_available_samples = active_total_samples - self.consumed_samples % active_total_samples
         if self.global_batch_size is not None:
             if self.drop_last:
-                return num_available_samples // self.global_batch_size
+                num_global_batches = num_available_samples // self.global_batch_size
             else:
-                return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
+                num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
+            # return len of dataloader in terms of micro batches to avoid discrepancy between len of dataloader and
+            # num of batches fetched (as training step fetches in terms of micro batches)
+            return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size)
         else:
             if self.drop_last:
                 return num_available_samples // self.micro_batch_times_data_parallel_size
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index 5321a307b2c4..6a2ea80ec764 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -27,6 +27,7 @@
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.trainer.trainer import Trainer
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 from nemo.collections.nlp.models.nlp_model import NLPModel
 from nemo.collections.nlp.modules.common.megatron.attention import HAVE_FLASH_ATTENTION
@@ -322,9 +323,37 @@ def _reconfigure_val_batches(self):
         """
         Reconfigure trainer.limit_val_batches for pretraining
         """
+        # Override limit_val_batches to be a multiple of num microbatches and so there are limit_val_batches//num_micro_batches num of global batches
         if isinstance(self.trainer.limit_val_batches, int):
-            # Override limit_val_batches to be a multiple of num microbatches and so there are limit_val_batches//num_micro_batches num of global batches
             self.trainer.limit_val_batches *= get_num_microbatches()
+        else:
+            assert isinstance(self.trainer.limit_val_batches, float)
+            # Don't reconfigure if limit_val_batches is 0.0
+            if self.trainer.limit_val_batches == 0.0:
+                return
+            # len(self._validation_dl) returns len as num of microbatches
+            val_len_in_micro_batches = len(self._validation_dl)
+            if self._validation_ds is not None and len(self._validation_dl) != float("inf"):
+                if self.trainer.limit_val_batches == 1.0:
+                    self.trainer.limit_val_batches = val_len_in_micro_batches
+                else:
+                    limit_val_micro_batches = int(val_len_in_micro_batches * self.trainer.limit_val_batches)
+                    if limit_val_micro_batches == 0 and self.trainer.limit_val_batches > 0.0:
+                        min_percentage = 1.0 / len(self._validation_dl)
+                        raise MisconfigurationException(
+                            f"You requested to check {self.trainer.limit_val_batches} of the val_dataloader but"
+                            f" {self.trainer.limit_val_batches} * {len(self._validation_dl)} < 1. Please increase the"
+                            f" `limit_val_batches` argument. Try at least"
+                            f" `limit_val_batches={min_percentage}`"
+                        )
+                    # Make sure trainer.limit_val_batches is a multiple of num of microbatches
+                    if limit_val_micro_batches < get_num_microbatches():
+                        self.trainer.limit_val_batches = get_num_microbatches()
+                    else:
+                        self.trainer.limit_val_batches = (
+                            limit_val_micro_batches - limit_val_micro_batches % get_num_microbatches()
+                        )
+
         # Override num sanity steps to be a multiple of num of microbatches
         self.trainer.num_sanity_val_steps *= get_num_microbatches()
 
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 752696ac8faa..950ce534e9bc 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1172,15 +1172,11 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor):
         return loss
 
     def build_train_valid_test_datasets(self):
-        # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
-        self._reconfigure_val_batches()
-        logging.info('Building GPT datasets.')
         if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float):
             raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.")
+        logging.info('Building GPT datasets.')
         global_batch_size = self.cfg.global_batch_size
         max_train_steps = self.trainer.max_steps
-        eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches
-        test_iters = self.trainer.limit_test_batches
 
         # Add extra FIM tokens to tokenizer
         if self.cfg.data.get('add_fim', False) and self.cfg.tokenizer.library == 'megatron':
@@ -1188,16 +1184,12 @@ def build_train_valid_test_datasets(self):
             fim_tokens = [fim_tokens.prefix, fim_tokens.middle, fim_tokens.suffix, fim_tokens.pad, fim_tokens.eod]
             self.tokenizer.add_special_tokens({'additional_special_tokens': fim_tokens})
 
-        train_valid_test_num_samples = [
-            max_train_steps * global_batch_size,
-            eval_iters * global_batch_size,
-            test_iters * global_batch_size,
-        ]
-
-        if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float):
-            train_valid_test_num_samples[
-                1
-            ] = 1  # This is to make sure we only have one epoch on every validation iteration
+        # The line below exploits a quirk in mcore dataset construction, to make number of epochs for validation and test equal to 1
+        # The mcore dataset implementation uses the number N we provide via train_valid_test_num_samples to derive parameter E such that
+        # E = argmin_e e * N_d >= N, or equivalently E = ceildiv(N, N_d)
+        # Where N_d is the total number of samples in a dataset (files), and N is the requested number of samples (provided for every split in the list below).
+        # Setting N = 1 we force E to be 1 as well
+        train_valid_test_num_samples = [max_train_steps * global_batch_size, 1, 1]
 
         mock_dataset = self.cfg.data.get("mock_dataset", False)
         kwargs = {
@@ -1325,6 +1317,8 @@ def setup(self, stage=None):
             self.setup_training_data(self.cfg.data)
             self.setup_validation_data(self.cfg.data)
             self.setup_test_data(self.cfg.data)
+            # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
+            self._reconfigure_val_batches()
 
         if stage == 'fit':
             self.initialize_last_rank_embeddings()
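
A minimal standalone sketch of the arithmetic this patch introduces (illustrative only, not code from the repo; the batch sizes and sample counts below are made-up assumptions): the sampler now reports its length in micro-batches rather than global batches, and a fractional limit_val_batches is rounded down to a multiple of the micro-batches per global batch so a validation step never stops partway through a global batch.

# Sketch of the sampler-length and limit_val_batches rounding logic; numbers are illustrative.

def sampler_len_in_micro_batches(total_samples, consumed_samples, global_batch_size,
                                 micro_batch_times_data_parallel_size, drop_last=True):
    """Count whole global batches, then report the dataloader length in micro-batches."""
    num_available_samples = total_samples - consumed_samples
    if drop_last:
        num_global_batches = num_available_samples // global_batch_size
    else:
        num_global_batches = (num_available_samples + global_batch_size - 1) // global_batch_size
    # micro-batches per global batch = global_batch_size // (micro_batch_size * data_parallel_size)
    return num_global_batches * (global_batch_size // micro_batch_times_data_parallel_size)


def round_limit_val_batches(limit_val_batches, val_len_in_micro_batches, num_micro_batches):
    """Turn a float limit_val_batches into a micro-batch count that is a multiple of num_micro_batches."""
    if limit_val_batches == 1.0:
        return val_len_in_micro_batches
    limit_val_micro_batches = int(val_len_in_micro_batches * limit_val_batches)
    if limit_val_micro_batches < num_micro_batches:
        return num_micro_batches
    return limit_val_micro_batches - limit_val_micro_batches % num_micro_batches


if __name__ == "__main__":
    # Assumed example: global_batch_size=16, micro_batch_size=2, data_parallel_size=2,
    # so there are 16 // (2 * 2) = 4 micro-batches per global batch.
    length = sampler_len_in_micro_batches(
        total_samples=1000, consumed_samples=0,
        global_batch_size=16, micro_batch_times_data_parallel_size=4,
    )
    print(length)                                   # 62 global batches * 4 = 248 micro-batches
    print(round_limit_val_batches(0.1, length, 4))  # int(248 * 0.1) = 24, already a multiple of 4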