diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index b7de4854784b..4eebe88952db 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -18,6 +18,7 @@ trainer: accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models gradient_clip_val: 1.0 benchmark: False + enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually exp_manager: explicit_log_dir: null @@ -47,7 +48,7 @@ model: global_batch_size: 8 # will use more micro batches to reach global batch size tensor_model_parallel_size: 1 # intra-layer model parallelism pipeline_model_parallel_size: 1 # inter-layer model parallelism - resume_from_checkpoint: null # manually set the checkpoint file to load from + virtual_pipeline_model_parallel_size: null # interleaved pipeline # model architecture encoder_seq_length: 512 @@ -92,6 +93,7 @@ model: # miscellaneous seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from use_cpu_initialization: False # Init weights on the CPU (slow for large models) onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this diff --git a/examples/nlp/language_modeling/megatron_gpt_eval.py b/examples/nlp/language_modeling/megatron_gpt_eval.py index 7e66c3096f33..e0f73d00993e 100644 --- a/examples/nlp/language_modeling/megatron_gpt_eval.py +++ b/examples/nlp/language_modeling/megatron_gpt_eval.py @@ -171,6 +171,7 @@ def main(cfg) -> None: app_state.model_parallel_size, app_state.data_parallel_size, app_state.pipeline_model_parallel_split_rank, + app_state.virtual_pipeline_model_parallel_rank, ) = fake_initialize_model_parallel( world_size=app_state.model_parallel_size, rank=trainer.global_rank, diff --git a/examples/nlp/language_modeling/megatron_t5_eval.py b/examples/nlp/language_modeling/megatron_t5_eval.py index 0c205ab65ad0..56b46f96a895 100644 --- a/examples/nlp/language_modeling/megatron_t5_eval.py +++ b/examples/nlp/language_modeling/megatron_t5_eval.py @@ -70,6 +70,7 @@ def main(): app_state.model_parallel_size, app_state.data_parallel_size, app_state.pipeline_model_parallel_split_rank, + app_state.virtual_pipeline_model_parallel_rank, ) = fake_initialize_model_parallel( world_size=app_state.model_parallel_size, rank=trainer.global_rank, diff --git a/examples/nlp/language_modeling/megatron_t5_prompt_learning_eval.py b/examples/nlp/language_modeling/megatron_t5_prompt_learning_eval.py index 812eb51975d3..a01c9b15d195 100644 --- a/examples/nlp/language_modeling/megatron_t5_prompt_learning_eval.py +++ b/examples/nlp/language_modeling/megatron_t5_prompt_learning_eval.py @@ -56,6 +56,7 @@ def main(cfg) -> None: app_state.model_parallel_size, app_state.data_parallel_size, app_state.pipeline_model_parallel_split_rank, + app_state.virtual_pipeline_model_parallel_rank, ) = fake_initialize_model_parallel( world_size=app_state.model_parallel_size, rank=trainer.global_rank, diff --git a/examples/nlp/language_modeling/tuning/megatron_t5_adapter_eval.py b/examples/nlp/language_modeling/tuning/megatron_t5_adapter_eval.py index 1430b8b6da03..6afc5a505917 100644 --- a/examples/nlp/language_modeling/tuning/megatron_t5_adapter_eval.py +++ 
b/examples/nlp/language_modeling/tuning/megatron_t5_adapter_eval.py @@ -57,6 +57,7 @@ def main(cfg) -> None: app_state.model_parallel_size, app_state.data_parallel_size, app_state.pipeline_model_parallel_split_rank, + app_state.virtual_pipeline_model_parallel_rank, ) = fake_initialize_model_parallel( world_size=app_state.model_parallel_size, rank=trainer.global_rank, diff --git a/examples/nlp/language_modeling/tuning/megatron_t5_ia3_eval.py b/examples/nlp/language_modeling/tuning/megatron_t5_ia3_eval.py index 21a1b926f8c6..d150353475d5 100644 --- a/examples/nlp/language_modeling/tuning/megatron_t5_ia3_eval.py +++ b/examples/nlp/language_modeling/tuning/megatron_t5_ia3_eval.py @@ -57,6 +57,7 @@ def main(cfg) -> None: app_state.model_parallel_size, app_state.data_parallel_size, app_state.pipeline_model_parallel_split_rank, + app_state.virtual_pipeline_model_parallel_rank, ) = fake_initialize_model_parallel( world_size=app_state.model_parallel_size, rank=trainer.global_rank, diff --git a/examples/nlp/machine_translation/nmt_transformer_infer_megatron.py b/examples/nlp/machine_translation/nmt_transformer_infer_megatron.py index a8d87f71dcbe..7bee2f562dc3 100644 --- a/examples/nlp/machine_translation/nmt_transformer_infer_megatron.py +++ b/examples/nlp/machine_translation/nmt_transformer_infer_megatron.py @@ -62,6 +62,7 @@ def main(cfg) -> None: app_state.model_parallel_size, app_state.data_parallel_size, app_state.pipeline_model_parallel_split_rank, + app_state.virtual_pipeline_model_parallel_rank, ) = fake_initialize_model_parallel( world_size=app_state.model_parallel_size, rank=trainer.global_rank, diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index f613384eaa76..8091567ead62 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -33,6 +33,7 @@ from nemo.collections.nlp.parts.nlp_overrides import GradScaler from nemo.core.optim import MainParamsOptimizerWrapper, prepare_lr_scheduler from nemo.utils import AppState, logging +from nemo.utils.get_rank import is_global_rank_zero try: from apex.transformer import parallel_state @@ -87,6 +88,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True): local_rank=trainer.local_rank, tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1), pipeline_model_parallel_size=cfg.get('pipeline_model_parallel_size', 1), + virtual_pipeline_model_parallel_size=cfg.get('virtual_pipeline_model_parallel_size', None), pipeline_model_parallel_split_rank=cfg.get('pipeline_model_parallel_split_rank', 0), micro_batch_size=cfg.get('micro_batch_size'), global_batch_size=cfg.get('global_batch_size'), @@ -389,3 +391,17 @@ def _validate_config(self): logging.info("Gradient accumulation fusion can only be used with megatron amp O2 mixed precision.") with open_dict(self.cfg): self.cfg.gradient_accumulation_fusion = False + + def is_data_parallel_rank_zero(self): + if is_global_rank_zero(): + return True + else: + try: + data_parallel_rank = parallel_state.get_data_parallel_rank() + except: + data_parallel_rank = None + + if data_parallel_rank is not None and data_parallel_rank == 0: + return True + else: + return False diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 2079b2a49a11..65f6bc281fc3 100755 --- 
a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -54,6 +54,9 @@ from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( forward_backward_pipelining_without_interleaving, ) + from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import ( + _forward_backward_pipelining_with_interleaving, + ) from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining HAVE_APEX = True @@ -77,29 +80,46 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self._validate_trainer() - # TODO: Not sure how to use lists of modules with PTL. - # This means we can only use pipeline parallelism without the interleaved schedule. - self.model = build_model(model_provider_func=self.model_provider_func, wrap_with_ddp=False)[0] - - # We don't need to call it explicitly? Since it is a pytorch lightning hook function - # self.setup_optimizer_param_groups() - self.megatron_amp_o2 = cfg.get('megatron_amp_O2', False) self.with_distributed_adam = cfg.optim.get('name') == 'distributed_fused_adam' + if not self.megatron_amp_o2 and self.cfg.get('virtual_pipeline_model_parallel_size', None): + raise ValueError('Virtual pipeline model parallel is only supported when using megatron_amp_O2') + if self.with_distributed_adam and not self.megatron_amp_o2: raise ValueError( "Distributed optimizers require O2. Please set megatron_amp_O2 to True in the model config." ) + # build_model returns a list of modules which are used for interleaved pipeline parallelism + self.model = build_model( + model_provider_func=self.model_provider_func, + wrap_with_ddp=False, + virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), + ) + + # if we're not using interleaved, then self.model is a module. 
+ if self.cfg.get('virtual_pipeline_model_parallel_size', None) is None: + self.model = self.model[0] + if self.megatron_amp_o2: if not self.with_distributed_adam: # Pre-allocate the model on GPU to have master parameters allocated on the same device with matching data type - self.model.cuda(torch.cuda.current_device()) + if isinstance(self.model, list): + for module in self.model: + module.cuda(torch.cuda.current_device()) + else: + self.model.cuda(torch.cuda.current_device()) # Model wrapper to convert both model and inputs to half precision - self.model = Float16Module(module=self.model, precision=cfg.precision) + if isinstance(self.model, list): + converted_model = [] + for module in self.model: + converted_model.append(Float16Module(module=module, precision=cfg.precision)) + self.model = converted_model + else: + self.model = Float16Module(module=self.model, precision=cfg.precision) if self.trainer.precision == 32: self.autocast_dtype = torch.float @@ -113,21 +133,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): # configuration used for inference self._inference_config = None - # At pipeline-parallel training, set the pipeline stage that the current GPU belongs to skip loading inputs - # Intemediate stage: doesn't need any inputs - # Fist pipeline stage: needs only 'tokens' and 'position_ids' - # Last pipeline stage: needs only 'labels' and 'loss_mask' - self._is_first_pipe_stage = False - self._is_intermediate_pipe_stage = False - self._is_last_pipe_stage = False - if parallel_state.get_pipeline_model_parallel_world_size() > 1: - if parallel_state.is_pipeline_first_stage(): - self._is_first_pipe_stage = True - elif parallel_state.is_pipeline_last_stage(): - self._is_last_pipe_stage = True - else: - self._is_intermediate_pipe_stage = True - def set_inference_config(self, inference_config): self._inference_config = inference_config @@ -172,9 +177,13 @@ def model_provider_func(self, pre_process, post_process): def setup_optimizer_param_groups(self): """ModelPT override. 
Optimizer will get self._optimizer_param_groups""" if self.cfg.get('do_layer_norm_weight_decay', False): - self._optimizer_param_groups = get_all_params_for_weight_decay_optimization([self.model]) + if isinstance(self.model, list): + self._optimizer_param_groups = get_all_params_for_weight_decay_optimization(self.model) + else: + self._optimizer_param_groups = get_all_params_for_weight_decay_optimization([self.model]) + else: - self._optimizer_param_groups = get_params_for_weight_decay_optimization([self.model]) + self._optimizer_param_groups = get_params_for_weight_decay_optimization(self.model) def setup_optimization( self, optim_config: Optional[Union[DictConfig, Dict]] = None, optim_kwargs: Optional[Dict[str, Any]] = None, @@ -190,6 +199,17 @@ def forward(self, tokens, text_position_ids, attention_mask, labels): output_tensor = self.model(tokens, text_position_ids, attention_mask, labels=labels) return output_tensor + def _get_fwd_bwd_function(self): + fwd_bwd_function = None + if self.cfg.get('pipeline_model_parallel_size', 1) > 1: + if self.cfg.get('virtual_pipeline_model_parallel_size', None) is not None: + fwd_bwd_function = _forward_backward_pipelining_with_interleaving + else: + fwd_bwd_function = forward_backward_pipelining_without_interleaving + else: + fwd_bwd_function = forward_backward_no_pipelining + return fwd_bwd_function + def training_step(self, batch, batch_idx): """ Our dataloaders produce a micro-batch and then we fetch @@ -203,46 +223,44 @@ def training_step(self, batch, batch_idx): # we zero grads here because we also call backward in the apex fwd/bwd functions self._optimizer.zero_grad() - if self._is_intermediate_pipe_stage: + if parallel_state.is_pipeline_first_stage(ignore_virtual=True) or parallel_state.is_pipeline_last_stage( + ignore_virtual=True + ): + # we prepare the micro batches for the apex fwd/bwd function + batch_for_pipeline = self.process_global_batch(batch) + else: # The intermediate pipeline stages do not need any inputs from data loader # GPT3 uses decoder with AttnMask:causal, thus doesn't need attention_mask batch_for_pipeline = None - else: - # we prepare the micro batches for the apex fwd/bwd function - batch_for_pipeline = self.process_global_batch(batch) + tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size] - if self.cfg.get('pipeline_model_parallel_size', 1) > 1: - losses_reduced_per_micro_batch = forward_backward_pipelining_without_interleaving( - forward_step_func=self.get_forward_output_and_loss_func(), - batch=batch_for_pipeline, - model=self.model, - forward_only=False, - tensor_shape=tensor_shape, - dtype=self.autocast_dtype, - grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, - sequence_parallel_enabled=self.cfg.get('sequence_parallel', False), - ) - else: - # no pipeline parallelism so we reduce grads asynchronously if not using sequence parallelism - if self.megatron_amp_o2 and not self.cfg.get('sequence_parallel', False): - if self.with_distributed_adam: - custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=True) - else: - custom_sync_context_handler = self._optimizer.no_sync + # determine if we can use async grad all reduce + custom_sync_context_handler = None + if self.megatron_amp_o2 and not self.cfg.get('sequence_parallel', False): + if self.with_distributed_adam: + custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=True) else: - # TODO: enable async grad all reduce for O1/autocast mixed 
precision training - custom_sync_context_handler = None - losses_reduced_per_micro_batch = forward_backward_no_pipelining( - forward_step_func=self.get_forward_output_and_loss_func(), - batch=batch_for_pipeline, - model=self.model, - forward_only=False, - tensor_shape=tensor_shape, - dtype=self.autocast_dtype, - grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, - custom_sync_context_handler=custom_sync_context_handler, - ) + custom_sync_context_handler = self._optimizer.no_sync + else: + # TODO: enable async grad all reduce for O1/autocast mixed precision training + custom_sync_context_handler = None + + # run forward and backwards passes for an entire global batch + # we do this inside training_step to support pipeline parallelism + fwd_bwd_function = self._get_fwd_bwd_function() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(), + batch=batch_for_pipeline, + model=self.model, + forward_only=False, + tensor_shape=tensor_shape, + dtype=self.autocast_dtype, + grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, + custom_sync_context_handler=custom_sync_context_handler, + sequence_parallel_enabled=self.cfg.get('sequence_parallel', False), + ) # only the last stages of the pipeline return losses if losses_reduced_per_micro_batch: @@ -311,19 +329,28 @@ def optimizer_zero_grad(self, *args, **kwargs): """ return - def allreduce_sequence_parallel_gradients(self): - """ All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used. - Modified from megatron-lm: - https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425 - """ - grads = [] - for param in self.model.parameters(): + def _append_module_grads(self, module, grads): + for param in module.parameters(): if getattr(param, 'sequence_parallel_enabled', False): if self.megatron_amp_o2: grad = param.main_grad else: grad = param.grad grads.append(grad.data) + + def allreduce_sequence_parallel_gradients(self): + """ All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used. + Modified from megatron-lm: + https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425 + """ + + grads = [] + if isinstance(self.model, list): + for module in self.model: + self._append_module_grads(module, grads) + else: + self._append_module_grads(self.model, grads) + coalesced = torch._utils._flatten_dense_tensors(grads) torch.distributed.all_reduce(coalesced, group=parallel_state.get_tensor_model_parallel_group()) for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)): @@ -337,10 +364,21 @@ def allreduce_first_last_embeddings(self): # This should only run for models that support pipelined model parallelism # (BERT and GPT-2). 
if parallel_state.get_pipeline_model_parallel_world_size() > 1 and ( - parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage() + parallel_state.is_pipeline_first_stage(ignore_virtual=True) + or parallel_state.is_pipeline_last_stage(ignore_virtual=True) ): - if self.model.share_token_embeddings: - word_embeddings_weight = self.model.word_embeddings_weight() + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + if isinstance(self.model, list): + module = self.model[0] # only the first virtual rank has the embeddings + else: + module = self.model + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + if isinstance(self.model, list): + module = self.model[-1] # only the last virtual rank has the embeddings + else: + module = self.model + if module.share_token_embeddings: + word_embeddings_weight = module.word_embeddings_weight() if self.megatron_amp_o2: # O2 recipe stores a "main" copy of weights and grads grad = word_embeddings_weight.main_grad @@ -356,21 +394,20 @@ def fwd_output_and_loss_func(batch, model): attention_mask = attention_mask[0:1] else: # GPT3 uses only causal mask, which doesn't need attention mask - if self._is_first_pipe_stage: + if parallel_state.is_pipeline_first_stage(): # Fist pipeline stage needs only the tokens and position_ids tokens = batch[0].cuda(non_blocking=True) position_ids = batch[4].cuda(non_blocking=True) labels, loss_mask, attention_mask = None, None, None - elif self._is_intermediate_pipe_stage: - # Intermediate pipeline stage doesn't need any inputs - tokens, labels, loss_mask, attention_mask, position_ids = None, None, None, None, None - elif self._is_last_pipe_stage: + elif parallel_state.is_pipeline_last_stage(): # Last pipeline stage needs only the labels and loss_mask labels = batch[1].cuda(non_blocking=True) loss_mask = batch[2].cuda(non_blocking=True) tokens, attention_mask, position_ids = None, None, None else: - assert False + # Intermediate pipeline stage doesn't need any inputs + tokens, labels, loss_mask, attention_mask, position_ids = None, None, None, None, None + output_tensor = model(tokens, position_ids, attention_mask, labels) def loss_func(output_tensor): @@ -423,26 +460,21 @@ def validation_step(self, batch, batch_idx): batch_for_pipeline = self.process_global_batch(batch) tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size] - if self.cfg.get('pipeline_model_parallel_size', 1) > 1: - losses_reduced_per_micro_batch = forward_backward_pipelining_without_interleaving( - forward_step_func=self.get_forward_output_and_loss_func(), - batch=batch_for_pipeline, - model=self.model, - forward_only=True, - tensor_shape=tensor_shape, - dtype=self.autocast_dtype, - sequence_parallel_enabled=self.cfg.get('sequence_parallel', False), - ) - else: - losses_reduced_per_micro_batch = forward_backward_no_pipelining( - forward_step_func=self.get_forward_output_and_loss_func(), - batch=batch_for_pipeline, - model=self.model, - forward_only=True, - tensor_shape=tensor_shape, - dtype=self.autocast_dtype, - ) + # run forward passes for an entire global batch + # we do this inside validation_step to support pipeline parallelism + fwd_bwd_function = self._get_fwd_bwd_function() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(), + batch=batch_for_pipeline, + model=self.model, + forward_only=True, + tensor_shape=tensor_shape, + dtype=self.autocast_dtype, + 
sequence_parallel_enabled=self.cfg.get('sequence_parallel', False), + ) + # only the last stage of the pipeline returns losses if losses_reduced_per_micro_batch: # average loss across micro batches loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] @@ -455,7 +487,8 @@ def validation_step(self, batch, batch_idx): return loss_mean def validation_epoch_end(self, outputs): - if parallel_state.is_pipeline_last_stage(): + + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): # only the last pipeline parallel stages return loss averaged_loss = torch.stack(outputs).mean() else: @@ -579,6 +612,41 @@ def setup(self, stage=None): Args: stage (str, optional): Can be 'fit', 'validate', 'test' or 'predict'. Defaults to None. """ + + # log number of parameters + if isinstance(self.model, list): + num_parameters_on_device = sum( + [sum([p.nelement() for p in model_module.parameters()]) for model_module in self.model] + ) + if parallel_state.get_pipeline_model_parallel_world_size() > 1 and parallel_state.is_pipeline_last_stage( + ignore_virtual=True + ): + # subtract the embedding weights on the last virtual stage + num_word_embedding_parameters = sum([p.nelement() for p in self.model[-1].word_embeddings_weight()]) + num_parameters_on_device -= num_word_embedding_parameters + else: + num_parameters_on_device = sum([p.nelement() for p in self.model.parameters()]) + + if parallel_state.get_pipeline_model_parallel_world_size() > 1 and parallel_state.is_pipeline_last_stage( + ignore_virtual=True + ): + # subtract the embedding weights on the last stage + num_word_embedding_parameters = sum([p.nelement() for p in self.model.word_embeddings_weight()]) + + num_parameters_on_device -= num_word_embedding_parameters + + # to be summed across data parallel group + total_num_parameters = torch.tensor(num_parameters_on_device).cuda() + + torch.distributed.all_reduce(total_num_parameters, group=parallel_state.get_model_parallel_group()) + + logging.info( + f'Pipeline model parallel rank: {parallel_state.get_pipeline_model_parallel_rank()}, ' + f'Tensor model parallel rank: {parallel_state.get_tensor_model_parallel_rank()}, ' + f'Number of model parameters on device: {num_parameters_on_device:.2e}. ' + f'Total number of model parameters: {total_num_parameters:.2e}.'
+ ) + resume_checkpoint_path = self.trainer._checkpoint_connector.resume_from_checkpoint_fit_path if resume_checkpoint_path: init_consumed_samples = self._extract_consumed_samples_from_ckpt(resume_checkpoint_path) @@ -599,7 +667,13 @@ def setup(self, stage=None): # when using pipeline model parallel the final stage need to initialize word embeddings if parallel_state.get_pipeline_model_parallel_world_size() > 1: - self.model.sync_initial_word_embeddings() + if isinstance(self.model, list): + for i, module in enumerate(self.model): + parallel_state.set_virtual_pipeline_model_parallel_rank(i) + module.sync_initial_word_embeddings() + parallel_state.set_virtual_pipeline_model_parallel_rank(0) + else: + self.model.sync_initial_word_embeddings() def setup_training_data(self, cfg): if hasattr(self, '_train_ds'): @@ -730,3 +804,23 @@ def list_available_models(cls) -> Optional[PretrainedModelInfo]: ) ) return result + + def on_save_checkpoint(self, checkpoint) -> None: + """LightningModule hook: + https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-save-checkpoint + """ + if isinstance(self.model, list): + for i in range(len(self.model)): + parallel_state.set_virtual_pipeline_model_parallel_rank(i) + checkpoint[f'model{i}'] = self.model[i].module.state_dict_for_save_checkpoint() + parallel_state.set_virtual_pipeline_model_parallel_rank(0) + + def on_load_checkpoint(self, checkpoint) -> None: + """LightningModule hook: + https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-load-checkpoint + """ + if isinstance(self.model, list): + for i in range(len(self.model)): + parallel_state.set_virtual_pipeline_model_parallel_rank(i) + self.model[i].module.load_state_dict(checkpoint[f'model{i}'], strict=True) + parallel_state.set_virtual_pipeline_model_parallel_rank(0) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index 206887f64a2e..e64f8760c3d1 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -25,6 +25,7 @@ from apex.transformer.parallel_state import ( get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank, + set_virtual_pipeline_model_parallel_rank, set_pipeline_model_parallel_split_rank, set_pipeline_model_parallel_world_size, set_tensor_model_parallel_rank, @@ -32,12 +33,21 @@ ) from apex.transformer.microbatches import ConstantNumMicroBatches from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator - from apex.transformer.utils import ensure_divisibility HAVE_APEX = True except (ImportError, ModuleNotFoundError): HAVE_APEX = False +try: + # TODO: remove when apex is updated + from apex.transformer.parallel_state import set_virtual_pipeline_model_parallel_world_size + + HAVE_INTERLEAVED = True + +except: + + HAVE_INTERLEAVED = False + def initialize_model_parallel_for_nemo( world_size, @@ -45,6 +55,7 @@ def initialize_model_parallel_for_nemo( local_rank, tensor_model_parallel_size=1, pipeline_model_parallel_size=1, + virtual_pipeline_model_parallel_size=None, pipeline_model_parallel_split_rank=None, micro_batch_size=None, global_batch_size=None, @@ -52,6 +63,9 @@ def initialize_model_parallel_for_nemo( apex_transformer_log_level=30, ): + if virtual_pipeline_model_parallel_size is not None and not HAVE_INTERLEAVED: + raise ValueError("set_virtual_pipeline_model_parallel_world_size is needed in Apex for interleaved.") + # 
updating NeMo globals app_state = AppState() app_state.global_rank = global_rank @@ -59,17 +73,20 @@ app_state.local_rank = local_rank app_state.tensor_model_parallel_size = tensor_model_parallel_size app_state.pipeline_model_parallel_size = pipeline_model_parallel_size + app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size ( app_state.tensor_model_parallel_rank, app_state.pipeline_model_parallel_rank, app_state.model_parallel_size, app_state.data_parallel_size, app_state.pipeline_model_parallel_split_rank, + app_state.virtual_pipeline_model_parallel_rank, ) = fake_initialize_model_parallel( world_size=world_size, rank=global_rank, tensor_model_parallel_size_=tensor_model_parallel_size, pipeline_model_parallel_size_=pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank, ) @@ -78,6 +95,9 @@ set_tensor_model_parallel_rank(app_state.tensor_model_parallel_rank) set_pipeline_model_parallel_rank(app_state.pipeline_model_parallel_rank) + if HAVE_INTERLEAVED: + set_virtual_pipeline_model_parallel_world_size(app_state.virtual_pipeline_model_parallel_size) + set_virtual_pipeline_model_parallel_rank(app_state.virtual_pipeline_model_parallel_rank) set_pipeline_model_parallel_world_size(app_state.pipeline_model_parallel_size) set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank) @@ -178,18 +198,17 @@ def fake_initialize_model_parallel( pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size) model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size - ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size) + assert ( + world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) == 0 + ), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size: {pipeline_model_parallel_size}' data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size) num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size - # TODO: virtual pipeline model parallelism is not yet implemented in NeMo. This is needed for interleaved pipelining. - # if virtual_pipeline_model_parallel_size_ is not None: - # global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - # global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - # _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 - # _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_ + virtual_pipeline_model_parallel_rank = None + if virtual_pipeline_model_parallel_size_ is not None: + virtual_pipeline_model_parallel_rank = 0 # Build the data-parallel groups.
all_data_parallel_group_ranks = [] @@ -272,4 +291,5 @@ def fake_initialize_model_parallel( model_parallel_size, data_parallel_size, pipeline_model_parallel_split_rank_, + virtual_pipeline_model_parallel_rank, ) diff --git a/nemo/collections/nlp/modules/common/megatron/module.py b/nemo/collections/nlp/modules/common/megatron/module.py index dab2eba1dcd2..8aa379629657 100644 --- a/nemo/collections/nlp/modules/common/megatron/module.py +++ b/nemo/collections/nlp/modules/common/megatron/module.py @@ -187,6 +187,11 @@ def sync_initial_position_embeddings(self): position_embeddings = self.position_embeddings_weight() torch.distributed.all_reduce(position_embeddings.data, group=parallel_state.get_position_embedding_group()) + def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + """Use this function to override the state dict for + saving checkpoints.""" + return self.state_dict(destination, prefix, keep_vars) + def sync_initial_encoder_relative_position_embeddings(self): # Ensure that all encoder RPE stages have the same weights. if parallel_state.is_rank_in_encoder_relative_position_embedding_group(): diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index 27b222dfd9d0..2e7c8ec05c4d 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -1763,7 +1763,9 @@ def build_layer(layer_number): assert num_layers % parallel_state.get_virtual_pipeline_model_parallel_world_size() == 0, ( 'num_layers_per_stage must be divisible by ' 'virtual_pipeline_model_parallel_size' ) - assert self.model_type != ModelType.encoder_or_decoder + + assert self.model_type.value != 2, 'virtual pipeline parallel is currently only supported for GPT' + # Number of layers in each model chunk is the number of layers in the stage, # divided by the number of model chunks in a stage. self.num_layers = self.num_layers // parallel_state.get_virtual_pipeline_model_parallel_world_size() diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 36f7e91e1a41..a693fc6fa993 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -146,6 +146,7 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None: tensor_model_parallel_size_=app_state.tensor_model_parallel_size, pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size, pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank, + virtual_pipeline_model_parallel_size_=app_state.virtual_pipeline_model_parallel_size, ) # assert that fake tp and pp rank match after model parallel init diff --git a/nemo/utils/app_state.py b/nemo/utils/app_state.py index f886008cc3d4..07af7ec62521 100644 --- a/nemo/utils/app_state.py +++ b/nemo/utils/app_state.py @@ -47,6 +47,7 @@ def __init__(self): self._tensor_model_parallel_size = None self._tensor_model_parallel_group = None self._pipeline_model_parallel_size = None + self._virtual_pipeline_model_parallel_size = None self._pipeline_model_parallel_group = None self._pipeline_model_parallel_split_rank = None self._is_megatron_initialized = False @@ -153,6 +154,22 @@ def pipeline_model_parallel_size(self, size): """ self._pipeline_model_parallel_size = size + @property + def virtual_pipeline_model_parallel_size(self): + """ Property returns the virtual pipeline model parallel size.
+ Returns: + Virtual pipeline model parallel size (number of model chunks per pipeline stage). + """ + return self._virtual_pipeline_model_parallel_size + + @virtual_pipeline_model_parallel_size.setter + def virtual_pipeline_model_parallel_size(self, size): + """ Property sets the virtual pipeline model parallel size. + Args: + size (int): Number of model chunks in each pipeline stage. + """ + self._virtual_pipeline_model_parallel_size = size + @property def data_parallel_size(self): """ Property returns the number of GPUs in each data parallel group. @@ -203,52 +220,68 @@ def global_rank(self, rank): @property def tensor_model_parallel_rank(self): - """ Property returns the model parallel rank. + """ Property returns the tensor model parallel rank. Returns: - Model parallel rank. + Tensor model parallel rank. """ return self._tensor_model_parallel_rank @tensor_model_parallel_rank.setter def tensor_model_parallel_rank(self, rank): - """ Property sets the model parallel rank. + """ Property sets the tensor model parallel rank. Args: - rank (int): Model parallel rank. + rank (int): Tensor model parallel rank. """ self._tensor_model_parallel_rank = rank @property def tensor_model_parallel_group(self): - """ Property returns the model parallel group. + """ Property returns the tensor model parallel group. Returns: - Model parallel group. + Tensor model parallel group. """ return self._tensor_model_parallel_group @tensor_model_parallel_group.setter def tensor_model_parallel_group(self, group): - """ Property sets the model parallel group. + """ Property sets the tensor model parallel group. Args: - group: Model parallel group. + group: Tensor model parallel group. """ self._tensor_model_parallel_group = group @property def pipeline_model_parallel_rank(self): - """ Property returns the model parallel rank. + """ Property returns the pipeline model parallel rank. Returns: - Model parallel rank. + Pipeline model parallel rank. """ return self._pipeline_model_parallel_rank @pipeline_model_parallel_rank.setter def pipeline_model_parallel_rank(self, rank): - """ Property sets the model parallel rank. + """ Property sets the pipeline model parallel rank. Args: - rank (int): Model parallel rank. + rank (int): Pipeline model parallel rank. """ self._pipeline_model_parallel_rank = rank + @property + def virtual_pipeline_model_parallel_rank(self): + """ Property returns the virtual pipeline parallel rank. + Returns: + Virtual pipeline parallel rank. + """ + return self._virtual_pipeline_model_parallel_rank + + @virtual_pipeline_model_parallel_rank.setter + def virtual_pipeline_model_parallel_rank(self, rank): + """ Property sets the virtual pipeline parallel rank. + Args: + rank (int): Virtual pipeline parallel rank. + """ + self._virtual_pipeline_model_parallel_rank = rank + @property def pipeline_model_parallel_split_rank(self): """ Property returns the rank at which Encoder and Decoder are split into different pipelines for Megatrron Encoder-Decoder models. @@ -267,17 +300,17 @@ def pipeline_model_parallel_split_rank(self, rank): @property def pipeline_model_parallel_group(self): - """ Property returns the model parallel group. + """ Property returns the pipeline model parallel group. Returns: - Model parallel group. + Pipeline model parallel group. """ return self._pipeline_model_parallel_group @pipeline_model_parallel_group.setter def pipeline_model_parallel_group(self, group): - """ Property sets the model parallel group. + """ Property sets the pipeline model parallel group. Args: - group: Model parallel group.
+ group: Pipeline model parallel group. """ self._pipeline_model_parallel_group = group diff --git a/scripts/nlp_language_modeling/convert_prompt_learning_ckpt_to_nemo.py b/scripts/nlp_language_modeling/convert_prompt_learning_ckpt_to_nemo.py index 59cd1e694b13..af9a01fa3ece 100644 --- a/scripts/nlp_language_modeling/convert_prompt_learning_ckpt_to_nemo.py +++ b/scripts/nlp_language_modeling/convert_prompt_learning_ckpt_to_nemo.py @@ -72,6 +72,7 @@ def main(cfg) -> None: app_state.model_parallel_size, app_state.data_parallel_size, app_state.pipeline_model_parallel_split_rank, + app_state.virtual_pipeline_model_parallel_rank, ) = fake_initialize_model_parallel( world_size=app_state.model_parallel_size, rank=trainer.global_rank,
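Usage note (illustrative, not part of the patch): based only on the constraints visible in this diff (the interleaved schedule requires megatron_amp_O2, is currently limited to GPT, and needs the per-pipeline-stage layer count to be divisible by the virtual size), a megatron_gpt_config.yaml override might look roughly like the sketch below; the concrete numbers are made up.

model:
  megatron_amp_O2: True                      # required; a ValueError is raised if virtual pipeline is set without O2
  num_layers: 24                             # illustrative
  pipeline_model_parallel_size: 4            # 24 / 4 = 6 layers per pipeline stage
  virtual_pipeline_model_parallel_size: 2    # 6 layers per stage split into 2 interleaved chunks of 3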