Handle float limit_val_batches #8426
Changes from all commits: bfa1df6, d859e67, e39a211, a1be065, 0c7fd6e, 85e22e7, 5ad42ec, c76cf6d, b34e708, 61e23fa, 35c4c7d, 70fdf34, 49f6148, b9c666a, 90931f1
```diff
@@ -27,6 +27,7 @@
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.trainer.trainer import Trainer
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 from nemo.collections.nlp.models.nlp_model import NLPModel
 from nemo.collections.nlp.modules.common.megatron.attention import HAVE_FLASH_ATTENTION

@@ -322,9 +323,37 @@ def _reconfigure_val_batches(self):
         """
         Reconfigure trainer.limit_val_batches for pretraining
         """
-        # Override limit_val_batches to be a multiple of num microbatches and so there are limit_val_batches//num_micro_batches num of global batches
-        self.trainer.limit_val_batches *= get_num_microbatches()
+        if isinstance(self.trainer.limit_val_batches, int):
+            # Override limit_val_batches to be a multiple of num microbatches and so there are limit_val_batches//num_micro_batches num of global batches
+            self.trainer.limit_val_batches *= get_num_microbatches()
+        else:
+            assert isinstance(self.trainer.limit_val_batches, float)
+            # Don't reconfigure if limit_val_batches is 0.0
+            if self.trainer.limit_val_batches == 0.0:
+                return
+            # len(self._validation_dl) returns len as num of microbatches
+            val_len_in_micro_batches = len(self._validation_dl)
+            if self._validation_ds is not None and len(self._validation_dl) != float("inf"):
```
Collaborator:
How do infinite validation sets work in NeMo? Is this even a valid approach? Shouldn't this be caught at the argument layer, even before dataset construction?

Collaborator (Author):
I ported that condition check over from PTL, as they check that the length of the val dataloader is not inf here while casting a float `limit_val_batches` to int.

Collaborator (Author):
The code now does the round-down logic only if `limit_val_batches` is greater than 0.0 and less than 1.0.

Collaborator:
As I understand this, we have 5 cases if […]. @athitten, do I understand this correctly? Doesn't it hang if […]? Can you write it as a series of if/elif cases? It would be way easier to read than following the nested ifs.

Collaborator (Author):
Nope, it does not hang when […].
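Picking up the reviewer's suggestion, here is one way the float branch could read as flat if/elif cases. This is a sketch, not the PR's code: the function name is hypothetical, `ValueError` stands in for `MisconfigurationException`, and the infinite-dataloader guard is omitted.

```python
def reconfigure_float_limit_val_batches(
    limit_val_batches: float, val_len_in_micro_batches: int, num_micro_batches: int
):
    """Flattened version of the float branch; returns the new limit_val_batches."""
    if limit_val_batches == 0.0:
        # Case 1: validation disabled, nothing to reconfigure.
        return limit_val_batches
    elif limit_val_batches == 1.0:
        # Case 2: run over the whole validation set, expressed in microbatches.
        return val_len_in_micro_batches
    elif int(val_len_in_micro_batches * limit_val_batches) == 0:
        # Case 3: the fraction selects less than one microbatch -> error out.
        raise ValueError(f"Try at least limit_val_batches={1.0 / val_len_in_micro_batches}")
    elif int(val_len_in_micro_batches * limit_val_batches) < num_micro_batches:
        # Case 4: fewer microbatches than one global batch -> round up to one global batch.
        return num_micro_batches
    else:
        # Case 5: round down to a multiple of the number of microbatches.
        limit = int(val_len_in_micro_batches * limit_val_batches)
        return limit - limit % num_micro_batches
```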
```diff
+                if self.trainer.limit_val_batches == 1.0:
+                    self.trainer.limit_val_batches = val_len_in_micro_batches
+                else:
+                    limit_val_micro_batches = int(val_len_in_micro_batches * self.trainer.limit_val_batches)
+                    if limit_val_micro_batches == 0 and self.trainer.limit_val_batches > 0.0:
+                        min_percentage = 1.0 / len(self._validation_dl)
+                        raise MisconfigurationException(
+                            f"You requested to check {self.trainer.limit_val_batches} of the val_dataloader but"
+                            f" {self.trainer.limit_val_batches} * {len(self._validation_dl)} < 1. Please increase the"
+                            f" `limit_val_batches` argument. Try at least"
+                            f" `limit_val_batches={min_percentage}`"
+                        )
+                    # Make sure trainer.limit_val_batches is a multiple of num of microbatches
+                    if limit_val_micro_batches < get_num_microbatches():
+                        self.trainer.limit_val_batches = get_num_microbatches()
+                    else:
+                        self.trainer.limit_val_batches = (
+                            limit_val_micro_batches - limit_val_micro_batches % get_num_microbatches()
+                        )

         # Override num sanity steps to be a multiple of num of microbatches
         self.trainer.num_sanity_val_steps *= get_num_microbatches()
```
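For concreteness, a small worked example of the rounding above (all numbers invented):

```python
# 50 validation microbatches in total; each global batch is 4 microbatches.
val_len_in_micro_batches = 50
num_micro_batches = 4

# limit_val_batches=0.5 selects int(50 * 0.5) = 25 microbatches. 25 is not a
# multiple of 4, so the code rounds down to 24, i.e. exactly 6 global batches.
limit = int(val_len_in_micro_batches * 0.5)   # 25
limit -= limit % num_micro_batches            # 24
assert limit == 24 and limit // num_micro_batches == 6

# limit_val_batches=0.01 selects int(50 * 0.01) = 0 microbatches, so the code
# raises MisconfigurationException suggesting at least 1/50 = 0.02.
assert int(val_len_in_micro_batches * 0.01) == 0
```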
```diff
@@ -1172,32 +1172,24 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor):
         return loss

     def build_train_valid_test_datasets(self):
-        # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
-        self._reconfigure_val_batches()
-        logging.info('Building GPT datasets.')
+        if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float):
+            raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.")
+        logging.info('Building GPT datasets.')
         global_batch_size = self.cfg.global_batch_size
         max_train_steps = self.trainer.max_steps
         eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches
         test_iters = self.trainer.limit_test_batches

         # Add extra FIM tokens to tokenizer
         if self.cfg.data.get('add_fim', False) and self.cfg.tokenizer.library == 'megatron':
             fim_tokens = self.cfg.data.fim.extra_tokens
             fim_tokens = [fim_tokens.prefix, fim_tokens.middle, fim_tokens.suffix, fim_tokens.pad, fim_tokens.eod]
             self.tokenizer.add_special_tokens({'additional_special_tokens': fim_tokens})

-        train_valid_test_num_samples = [
-            max_train_steps * global_batch_size,
-            eval_iters * global_batch_size,
-            test_iters * global_batch_size,
-        ]
-
-        if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float):
-            train_valid_test_num_samples[
-                1
-            ] = 1  # This is to make sure we only have one epoch on every validation iteration
+        # The line below exploits a quirk in mcore dataset construction, to make number of epochs for validation and test equal to 1
+        # The mcore dataset implementation uses the number N we provide via train_valid_test_num_samples to derive parameter E such that
+        # E = argmin_e e * N_d >= N, or equivalently E = ceildiv(N, N_d)
+        # Where N_d is the total number of samples in a dataset (files), and N is the requested number of samples (provided for every split in the list below).
+        # Setting N = 1 we force E to be 1 as well
+        train_valid_test_num_samples = [max_train_steps * global_batch_size, 1, 1]
```
Collaborator (Author):
Having this is currently causing an index error with the blended dataset from mcore. Also, I couldn't understand the purpose of having it. Waiting for @shanmugamr1992's comments on it.

Collaborator:
Leaving here a comment from an offline conversation with the explanation (the reasoning is captured in the comment block added in the diff above).
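To illustrate the formula from that comment block, E = ceildiv(N, N_d), with invented numbers:

```python
import math

# E = argmin_e (e * N_d >= N), equivalently E = ceildiv(N, N_d), where N_d is
# the dataset size and N the requested number of samples for the split.
def num_epochs(requested_samples: int, dataset_size: int) -> int:
    return math.ceil(requested_samples / dataset_size)

assert num_epochs(25_000, 10_000) == 3  # serving 25k samples from 10k takes 3 passes
assert num_epochs(1, 10_000) == 1       # N = 1 pins the epoch count E to 1
```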
```diff
         mock_dataset = self.cfg.data.get("mock_dataset", False)

         kwargs = {
```
```diff
@@ -1325,6 +1317,8 @@ def setup(self, stage=None):
         self.setup_training_data(self.cfg.data)
         self.setup_validation_data(self.cfg.data)
         self.setup_test_data(self.cfg.data)
+        # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
+        self._reconfigure_val_batches()

         if stage == 'fit':
             self.initialize_last_rank_embeddings()
```
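Note on placement: `_reconfigure_val_batches()` now runs only after `setup_validation_data` has built the dataloaders, because the float branch added in this PR reads `len(self._validation_dl)`, which does not exist before that point.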
Collaborator:
There is a check on the line below (`self._validation_ds is not None`), but on this line (`val_len_in_micro_batches = len(self._validation_dl)`) we assume the existence of `self._validation_dl`. Can a data loader be created without a dataset?
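If a dataloader can indeed exist without a dataset (or vice versa), one defensive variant would be to guard on the dataloader itself. A sketch only, with a hypothetical helper name, not what the PR does:

```python
def finite_val_len_in_micro_batches(model):
    """Return the validation dataloader length in microbatches, or None when no
    finite validation dataloader exists. Illustrative sketch only."""
    validation_dl = getattr(model, "_validation_dl", None)
    if validation_dl is None:
        return None
    length = len(validation_dl)
    return None if length == float("inf") else length
```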