Jenkinsfile — 4 changes: 2 additions & 2 deletions

@@ -3485,7 +3485,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
     trainer.accelerator=gpu \
     trainer.log_every_n_steps=1 \
     trainer.val_check_interval=2 \
-    trainer.limit_val_batches=2 \
+    trainer.limit_val_batches=1.0 \
     trainer.accumulate_grad_batches=1 \
     trainer.max_steps=3 \
     trainer.precision=16 \
@@ -3520,7 +3520,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
     trainer.accelerator=gpu \
     trainer.log_every_n_steps=1 \
     trainer.val_check_interval=2 \
-    trainer.limit_val_batches=2 \
+    trainer.limit_val_batches=1.0 \
     trainer.accumulate_grad_batches=1 \
     trainer.max_steps=6 \
     trainer.precision=16 \
@@ -81,9 +81,12 @@ def __len__(self):
         num_available_samples: int = self.total_samples - self.consumed_samples
         if self.global_batch_size is not None:
             if self.drop_last:
-                return num_available_samples // self.global_batch_size
+                num_global_batches = num_available_samples // self.global_batch_size
             else:
-                return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
+                num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
+            # return len of dataloader in terms of micro batches to avoid discrepancy between len of dataloader and
+            # num of batches fetched (as training step fetches in terms of micro batches)
+            return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size)
         else:
             return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1

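To see what the new __len__ returns, here is a small self-contained sketch with made-up numbers (the variable names mirror the sampler fields in the diff above; this is an editor's illustration, not NeMo code):

# Hypothetical sampler state, chosen only to illustrate the arithmetic (drop_last=True).
total_samples = 100
consumed_samples = 0
global_batch_size = 8
micro_batch_times_data_parallel_size = 2  # micro_batch_size * data_parallel_size

num_available_samples = total_samples - consumed_samples         # 100
num_global_batches = num_available_samples // global_batch_size  # 12 with drop_last

# Old __len__: counted global batches.
print(num_global_batches)  # 12

# New __len__: counts micro batches, matching what the training step fetches
# (each global batch is 8 // 2 = 4 micro batches here).
print(num_global_batches * (global_batch_size // micro_batch_times_data_parallel_size))  # 48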
@@ -162,9 +165,12 @@ def __len__(self):
         num_available_samples = active_total_samples - self.consumed_samples % active_total_samples
         if self.global_batch_size is not None:
             if self.drop_last:
-                return num_available_samples // self.global_batch_size
+                num_global_batches = num_available_samples // self.global_batch_size
             else:
-                return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
+                num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
+            # return len of dataloader in terms of micro batches to avoid discrepancy between len of dataloader and
+            # num of batches fetched (as training step fetches in terms of micro batches)
+            return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size)
         else:
             if self.drop_last:
                 return num_available_samples // self.micro_batch_times_data_parallel_size
@@ -27,6 +27,7 @@
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.trainer.trainer import Trainer
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 from nemo.collections.nlp.models.nlp_model import NLPModel
 from nemo.collections.nlp.modules.common.megatron.attention import HAVE_FLASH_ATTENTION
@@ -322,9 +323,37 @@ def _reconfigure_val_batches(self):
"""
Reconfigure trainer.limit_val_batches for pretraining
"""
# Override limit_val_batches to be a multiple of num microbatches and so there are limit_val_batches//num_micro_batches num of global batches
if isinstance(self.trainer.limit_val_batches, int):
# Override limit_val_batches to be a multiple of num microbatches and so there are limit_val_batches//num_micro_batches num of global batches
self.trainer.limit_val_batches *= get_num_microbatches()
else:
assert isinstance(self.trainer.limit_val_batches, float)
# Don't reconfigure if limit_val_batches is 0.0
if self.trainer.limit_val_batches == 0.0:
return
# len(self._validation_dl) returns len as num of microbatches
val_len_in_micro_batches = len(self._validation_dl)
jbaczek (Collaborator), Feb 21, 2024:

There is a check in the line below (self._validation_ds is not None), but this line already assumes the existence of self._validation_dl. Can a data loader be created without a dataset?
+            if self._validation_ds is not None and len(self._validation_dl) != float("inf"):
Collaborator:

How do infinite validation sets work in NeMo? Is this even a valid approach? Shouldn't this be caught at the argument layer, before the dataset is even constructed?

athitten (Collaborator, Author):

I ported that condition check over from PTL: they check that the length of the val dataloader is not inf while casting a float limit_val_batches to an int. I don't think NeMo has any cases with infinite validation sets.

athitten (Collaborator, Author):

The code now does the round-down logic only if limit_val_batches > 0.0 and less than 1.0.

jbaczek (Collaborator), Feb 21, 2024:

As I understand this, we have 5 cases if limit_val_batches is a float:

  1. limit_val_batches == 0.0. Then we keep this value and let PTL skip validation.
  2. limit_val_batches == 1.0. Then we set self.trainer.limit_val_batches = len(self._validation_dl), even if it is not divisible by get_num_microbatches().
  3. limit_val_batches * len(self._validation_dl) < 1. Then we raise an exception.
  4. get_num_microbatches() > limit_val_batches * len(self._validation_dl) >= 1. Then we fix limit_val_batches = get_num_microbatches() to run at least one iteration.
  5. limit_val_batches * len(self._validation_dl) > get_num_microbatches(). Then we round down to an integer multiple of get_num_microbatches().

@athitten do I understand this correctly? Doesn't it hang if limit_val_batches == 1.0?

Can you write it as a series of if/elif cases? That would be much easier to read than the nested ifs.

athitten (Collaborator, Author), Feb 21, 2024:

No, it does not hang when limit_val_batches == 1.0. In that case we use the full length of the dataloader as limit_val_batches, and since len(dataloader) is already in micro batches, it is divisible by get_num_microbatches(). The case where we must ensure a multiple of get_num_microbatches() arises when 0 < limit_val_batches < 1.0, i.e. a true fraction.
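For readers following the thread, the five cases can be written as the flat if/elif chain the reviewer asked for. This is an editor's sketch only — it mirrors the merged logic under the names used in the diff (trainer.limit_val_batches, _validation_dl, get_num_microbatches()), not the code that shipped:

def reconfigure_float_limit_val_batches(trainer, validation_dl, num_micro_batches):
    # Sketch of the 5 float cases discussed above; len(validation_dl) is in micro batches.
    lvb = trainer.limit_val_batches
    dl_len = len(validation_dl)
    requested = int(dl_len * lvb)  # requested number of micro batches

    if lvb == 0.0:
        return                                         # case 1: PTL skips validation
    elif lvb == 1.0:
        trainer.limit_val_batches = dl_len             # case 2: full dataloader, already in micro batches
    elif requested == 0:
        raise ValueError(                              # case 3: fraction selects less than one micro batch
            f"limit_val_batches={lvb} * {dl_len} < 1; try at least {1.0 / dl_len}"
        )
    elif requested < num_micro_batches:
        trainer.limit_val_batches = num_micro_batches  # case 4: run at least one global batch
    else:
        trainer.limit_val_batches = requested - requested % num_micro_batches  # case 5: round down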

+                if self.trainer.limit_val_batches == 1.0:
+                    self.trainer.limit_val_batches = val_len_in_micro_batches
+                else:
+                    limit_val_micro_batches = int(val_len_in_micro_batches * self.trainer.limit_val_batches)
+                    if limit_val_micro_batches == 0 and self.trainer.limit_val_batches > 0.0:
+                        min_percentage = 1.0 / len(self._validation_dl)
+                        raise MisconfigurationException(
+                            f"You requested to check {self.trainer.limit_val_batches} of the val_dataloader but"
+                            f" {self.trainer.limit_val_batches} * {len(self._validation_dl)} < 1. Please increase the"
+                            f" `limit_val_batches` argument. Try at least"
+                            f" `limit_val_batches={min_percentage}`"
+                        )
+                    # Make sure trainer.limit_val_batches is a multiple of num of microbatches
+                    if limit_val_micro_batches < get_num_microbatches():
+                        self.trainer.limit_val_batches = get_num_microbatches()
+                    else:
+                        self.trainer.limit_val_batches = (
+                            limit_val_micro_batches - limit_val_micro_batches % get_num_microbatches()
+                        )

         # Override num sanity steps to be a multiple of num of microbatches
         self.trainer.num_sanity_val_steps *= get_num_microbatches()
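Concretely, with hypothetical numbers (not from the PR): a validation dataloader of 64 micro batches, 4 micro batches per global batch, and limit_val_batches=0.3:

val_len_in_micro_batches = 64
num_micro_batches = 4

limit_val_micro_batches = int(val_len_in_micro_batches * 0.3)  # 19
# 19 is not a multiple of 4, so it is rounded down to 16: exactly 4 full global
# batches, and validation can no longer stop in the middle of a global batch.
print(limit_val_micro_batches - limit_val_micro_batches % num_micro_batches)  # 16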

@@ -1172,32 +1172,24 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor):
         return loss

     def build_train_valid_test_datasets(self):
-        # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
-        self._reconfigure_val_batches()
-        logging.info('Building GPT datasets.')
+        if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float):
+            raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.")
+        logging.info('Building GPT datasets.')
         global_batch_size = self.cfg.global_batch_size
         max_train_steps = self.trainer.max_steps
         eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches
         test_iters = self.trainer.limit_test_batches

         # Add extra FIM tokens to tokenizer
         if self.cfg.data.get('add_fim', False) and self.cfg.tokenizer.library == 'megatron':
             fim_tokens = self.cfg.data.fim.extra_tokens
             fim_tokens = [fim_tokens.prefix, fim_tokens.middle, fim_tokens.suffix, fim_tokens.pad, fim_tokens.eod]
             self.tokenizer.add_special_tokens({'additional_special_tokens': fim_tokens})

-        train_valid_test_num_samples = [
-            max_train_steps * global_batch_size,
-            eval_iters * global_batch_size,
-            test_iters * global_batch_size,
-        ]
-
-        if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float):
-            train_valid_test_num_samples[
-                1
-            ] = 1  # This is to make sure we only have one epoch on every validation iteration
+        # The line below exploits a quirk in mcore dataset construction to make the number of epochs for validation and test equal to 1.
+        # The mcore dataset implementation uses the number N we provide via train_valid_test_num_samples to derive a parameter E such that
+        #     E = argmin_e (e * N_d >= N), or equivalently E = ceildiv(N, N_d),
+        # where N_d is the total number of samples in a dataset (files) and N is the requested number of samples (provided per split in the list below).
+        # Setting N = 1 forces E to be 1 as well.
+        train_valid_test_num_samples = [max_train_steps * global_batch_size, 1, 1]

athitten (Collaborator, Author):

Having this was causing an index error with the blended dataset from mcore, and I couldn't understand the purpose of having it. Waiting for @shanmugamr1992's comments on it.

jbaczek (Collaborator):

Leaving here a comment from an offline conversation with the explanation:

This is a hack that uses a quirk in the dataset implementation, causing the behaviour explained in the comment. This array supplies the N value in the formula E = argmin_e (e * len(data) >= N). E is then used to construct the indices that serve as the internal representation of the dataset. The dataset is not trimmed to N after construction, but to E * len(data). So by setting this to 1 we are sure that E is exactly 1, which gives us a single iteration over a validation split.
[...]
In plain words: the number of epochs E is the smallest integer that, multiplied by the length of the data, is greater than or equal to the requested number of samples.
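The quirk is easy to check numerically. A standalone sketch of the formula (editor's illustration, not mcore code):

import math

def num_epochs(requested_samples: int, dataset_len: int) -> int:
    # E = argmin_e (e * N_d >= N), i.e. E = ceildiv(N, N_d)
    return math.ceil(requested_samples / dataset_len)

print(num_epochs(1, 5000))      # 1 -> requesting N=1 pins validation/test to a single epoch
print(num_epochs(12000, 5000))  # 3 -> the dataset is then trimmed to E*N_d = 15000, not to N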

         mock_dataset = self.cfg.data.get("mock_dataset", False)
         kwargs = {
@@ -1325,6 +1317,8 @@ def setup(self, stage=None):
             self.setup_training_data(self.cfg.data)
             self.setup_validation_data(self.cfg.data)
             self.setup_test_data(self.cfg.data)
+            # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in the middle of a step
+            self._reconfigure_val_batches()

         if stage == 'fit':
             self.initialize_last_rank_embeddings()