diff --git a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py
index f286bb9a8adf..17ffc01fb7f4 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py
@@ -1330,7 +1330,7 @@ def get_samples_mapping(
         )
         torch.distributed.barrier()
         counts = torch.cuda.LongTensor([1])
-        torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group())
+        torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group(with_context_parallel=True))
         torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group())
         assert counts[0].item() == (
             torch.distributed.get_world_size()
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py
index e86eb73add74..b7fec4f38e1e 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py
@@ -313,6 +313,7 @@ def __init__(
         self.indexed_dataset = indexed_dataset
         self.drop_last = drop_last
         self.seq_length = seq_length
+        self.get_attention_mask_from_fusion = cfg.get('get_attention_mask_from_fusion', True)
 
         # Checks
         assert np.min(documents) >= 0
@@ -433,13 +434,21 @@ def __getitem__(self, idx):
             logging.debug('Got negative index. Masking loss from this sample')
             loss_mask = torch.zeros_like(loss_mask)
 
-        return {
-            'tokens': tokens,
-            'labels': labels,
-            'attention_mask': attention_mask,
-            'loss_mask': loss_mask,
-            'position_ids': position_ids,
-        }
+        if self.get_attention_mask_from_fusion:
+            return {
+                'tokens': tokens,
+                'labels': labels,
+                'loss_mask': loss_mask,
+                'position_ids': position_ids,
+            }
+        else:
+            return {
+                'tokens': tokens,
+                'labels': labels,
+                'attention_mask': attention_mask,
+                'loss_mask': loss_mask,
+                'position_ids': position_ids,
+            }
 
 
 class MockGPTDataset(Dataset):
@@ -457,6 +466,7 @@ def __init__(
         self.vocab_size = tokenizer.vocab_size
         self.length = num_samples
         self.seed = seed
+        self.get_attention_mask_from_fusion = cfg.get('get_attention_mask_from_fusion', True)
 
         self.attention_mask = torch.tril(torch.ones((self.seq_length, self.seq_length))).unsqueeze(0)
         self.attention_mask = self.attention_mask < 0.5
@@ -476,13 +486,21 @@ def __getitem__(self, idx):
         tokens = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64))
         labels = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64))
 
-        return {
-            'tokens': tokens,
-            'labels': labels,
-            'attention_mask': self.attention_mask,
-            'loss_mask': self.loss_mask,
-            'position_ids': self.position_ids,
-        }
+        if self.get_attention_mask_from_fusion:
+            return {
+                'tokens': tokens,
+                'labels': labels,
+                'loss_mask': self.loss_mask,
+                'position_ids': self.position_ids,
+            }
+        else:
+            return {
+                'tokens': tokens,
+                'labels': labels,
+                'attention_mask': self.attention_mask,
+                'loss_mask': self.loss_mask,
+                'position_ids': self.position_ids,
+            }
 
 
 @torch.no_grad()
@@ -674,7 +692,7 @@ def _build_index_mappings(
 
     torch.distributed.barrier()
     counts = torch.cuda.LongTensor([1])
-    torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group())
+    torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group(with_context_parallel=True))
    torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group())
     assert counts[0].item() == (
         torch.distributed.get_world_size()
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index 30e142a1afa6..044842c39941 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -169,6 +169,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
             pipeline_model_parallel_size=cfg.get('pipeline_model_parallel_size', 1),
             virtual_pipeline_model_parallel_size=vp_size,
             pipeline_model_parallel_split_rank=cfg.get('pipeline_model_parallel_split_rank', 0),
+            context_parallel_size=cfg.get('context_parallel_size', 1),
             micro_batch_size=cfg.get('micro_batch_size'),
             global_batch_size=cfg.get('global_batch_size'),
             rampup_batch_size=cfg.get('rampup_batch_size', None),
@@ -231,6 +232,27 @@ def setup_transformer_engine_tp_groups(self):
                 tp_group = parallel_state.get_tensor_model_parallel_group()
                 child.set_tensor_parallel_group(tp_group)
 
+    def setup_transformer_engine_cp_groups(self):
+        """ This should be called after context parallel groups have been initialized
+            and only needs to be called when using Transformer Engine.
+        """
+        cp_stream = torch.cuda.Stream()
+
+        for module in self.get_model_module_list():
+            """Set context parallel running
+            Copied from: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py
+            """
+            # Deep iterate but skip self to avoid infinite recursion.
+            for index, child in enumerate(module.modules()):
+                if index == 0:
+                    continue
+                if hasattr(child, "set_context_parallel_group"):
+                    child.set_context_parallel_group(
+                        parallel_state.get_context_parallel_group(),
+                        parallel_state.get_context_parallel_global_ranks(),
+                        cp_stream,
+                    )
+
     def _wrap_model_for_O2(self):
         """ Wraps self.model in a float16 wrapper if the model is using megatron amp O2.
             Args:
@@ -556,8 +578,10 @@ def allreduce_gradients(self):
             bucket = buckets[tp]
             grads = [param.grad.data for param in bucket]
             coalesced = torch._utils._flatten_dense_tensors(grads)
-            coalesced /= parallel_state.get_data_parallel_world_size()
-            torch.distributed.all_reduce(coalesced, group=parallel_state.get_data_parallel_group())
+            coalesced /= parallel_state.get_data_parallel_world_size(with_context_parallel=True)
+            torch.distributed.all_reduce(
+                coalesced, group=parallel_state.get_data_parallel_group(with_context_parallel=True)
+            )
             for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                 buf.copy_(synced)
 
@@ -633,7 +657,6 @@ def setup_optimization(
     ):
         optim_kwargs = {} if optim_kwargs is None else optim_kwargs.copy()
         if self.with_distributed_adam:
-
            # Allocate contiguous buffer to avoid extra copies
            optim_kwargs['contiguous_grad_buffer'] = True
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index b7e5fc08b1f8..332b76ee32be 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -278,7 +278,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         # Convert the global-batch-based profile index to micro-batch index
         if hasattr(self, '_nsys_profile_enabled'):
             mp_size = cfg.get('tensor_model_parallel_size', 1) * cfg.get('pipeline_model_parallel_size', 1)
-            data_parallel_world_size = trainer.world_size // mp_size
+            cp_size = cfg.get('context_parallel_size', 1)
+            data_parallel_world_size = trainer.world_size // (mp_size * cp_size)
             grad_accum_steps = cfg.get('global_batch_size') // (cfg.get('micro_batch_size') * data_parallel_world_size)
             self._nsys_profile_start_step *= grad_accum_steps
             self._nsys_profile_end_step *= grad_accum_steps
@@ -553,7 +554,9 @@ def initialize_ub_func(self):
         )
 
         input_shape = [
-            self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
+            self.cfg.get('encoder_seq_length')
+            * self.cfg.get('micro_batch_size')
+            // self.cfg.get('context_parallel_size', 1),
             self.cfg.get('hidden_size'),
         ]
 
@@ -834,6 +837,32 @@ def __next__(self):
         # TODO @tmoon: Use once available in Megatron-LM
         # return DataIteratorList(iters)
 
+    def get_batch_on_this_context_parallel_rank(self, batch):
+        cp_size = self.cfg.get('context_parallel_size', 1)
+        num_valid_tokens_in_ub = None
+        if 'loss_mask' in batch and batch['loss_mask'] is not None:
+            num_valid_tokens_in_ub = batch['loss_mask'].sum()
+
+        if cp_size > 1:
+            cp_rank = parallel_state.get_context_parallel_rank()
+            for key, val in batch.items():
+                if val is not None:
+                    seq_dim = 1 if key != 'attention_mask' else 2
+                    val = val.view(
+                        *val.shape[0:seq_dim],
+                        2 * cp_size,
+                        val.shape[seq_dim] // (2 * cp_size),
+                        *val.shape[(seq_dim + 1) :],
+                    )
+                    index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=val.device)
+                    val = val.index_select(seq_dim, index)
+                    val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
+                    batch[key] = val
+
+        batch['num_valid_tokens_in_ub'] = num_valid_tokens_in_ub
+
+        return batch
+
     def get_forward_output_and_loss_func(self, validation_step=False):
         def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None):
@@ -852,15 +881,17 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
                     required_keys.update(('tokens', 'position_ids'))
                 if parallel_state.is_pipeline_last_stage():
                     required_keys.update(('labels', 'loss_mask'))
-            if self.get_attention_mask_from_fusion:
+            if self.get_attention_mask_from_fusion and 'attention_mask' in required_keys:
                 required_keys.remove('attention_mask')
            batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()}
 
+            batch = self.get_batch_on_this_context_parallel_rank(batch)
+
             # Model forward pass
             forward_args = {
                 'input_ids': batch['tokens'],
                 'position_ids': batch['position_ids'],
-                'attention_mask': batch['attention_mask'],
+                'attention_mask': None if self.get_attention_mask_from_fusion else batch['attention_mask'],
                 'labels': batch['labels'],
                 'loss_mask': batch['loss_mask'],
             }
@@ -885,9 +916,10 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
 
             def loss_func(output_tensor):
                 # Loss for a micro-batch (ub)
-                loss_for_ub = self.loss_func(batch['loss_mask'], output_tensor)
+                loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor)
+                cp_size = self.cfg.get('context_parallel_size', 1)
                 if validation_step and not self.cfg.data.get('validation_drop_last', True):
-                    num_valid_tokens_in_ub = batch['loss_mask'].sum()
+                    num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub']
                     if loss_for_ub.isnan():
                         assert batch['loss_mask'].count_nonzero() == 0, 'Got NaN loss with non-empty input'
                         loss_sum_for_ub = torch.zeros_like(num_valid_tokens_in_ub)
@@ -904,10 +936,10 @@ def loss_func(output_tensor):
                     torch.distributed.all_reduce(
                         loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group()
                     )
-                    return loss_for_ub, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu}
+                    return loss_for_ub * cp_size, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu}
                 else:
                     reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
-                    return loss_for_ub, {'avg': reduced_loss}
+                    return loss_for_ub * cp_size, {'avg': reduced_loss}
 
             return output_tensor, loss_func
 
@@ -1007,10 +1039,11 @@ def on_validation_epoch_end(self):
         if parallel_state.get_pipeline_model_parallel_world_size() > 1:
             if self.loss_broadcast_src_rank is None:
                 dp_size = parallel_state.get_data_parallel_world_size()
+                cp_size = parallel_state.get_context_parallel_world_size()
                 tp_size = parallel_state.get_tensor_model_parallel_world_size()
                 pp_size = parallel_state.get_pipeline_model_parallel_world_size()
-                rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * tp_size)
-                last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1)
+                rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * cp_size * tp_size)
+                last_pipeline_stage_offset = (tp_size * cp_size * dp_size) * (pp_size - 1)
                 self.loss_broadcast_src_rank = last_pipeline_stage_offset + rank_in_dp_tp_group
             torch.distributed.broadcast(
                 averaged_loss, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(),
@@ -1029,11 +1062,14 @@ def on_test_epoch_end(self):
         logging.info(f'test_loss: {averaged_loss[0]}')
         self.test_step_outputs.clear()  # free memory
 
-    def loss_func(self, loss_mask, output_tensor):
+    def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor):
         losses = output_tensor.float()
         loss_mask = loss_mask.view(-1).float()
         # TODO: add nemo version here
-        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()  # sequence level nll
+        loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub  # sequence level nll
+        cp_size = self.cfg.get('context_parallel_size', 1)
+        if cp_size > 1:
+            torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())
         return loss
 
     def build_train_valid_test_datasets(self):
@@ -1185,6 +1221,7 @@ def setup(self, stage=None):
 
         if self.cfg.get('transformer_engine', False) or self.cfg.get('mcore_gpt', False):
             self.setup_transformer_engine_tp_groups()
+            self.setup_transformer_engine_cp_groups()
 
     def setup_training_data(self, cfg):
         if hasattr(self, '_train_ds'):
@@ -1243,6 +1280,7 @@ def dummy():
 
         if self.cfg.get('transformer_engine', False):
             self.setup_transformer_engine_tp_groups()
+            self.setup_transformer_engine_cp_groups()
 
         # set the default sampling params if it is None.
         # default do greedy sampling
diff --git a/nemo/collections/nlp/modules/common/megatron/build_model.py b/nemo/collections/nlp/modules/common/megatron/build_model.py
index 929093405fce..2749eae846cd 100644
--- a/nemo/collections/nlp/modules/common/megatron/build_model.py
+++ b/nemo/collections/nlp/modules/common/megatron/build_model.py
@@ -151,7 +151,10 @@ def build_model(
         i = torch.cuda.current_device()
         model = [
             torch.nn.parallel.distributed.DistributedDataParallel(
-                model_module, device_ids=[i], output_device=i, process_group=parallel_state.get_data_parallel_group(),
+                model_module,
+                device_ids=[i],
+                output_device=i,
+                process_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
             )
             for model_module in model
         ]
diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py
index ea8da795de31..bbeeade2d8c5 100755
--- a/nemo/collections/nlp/modules/common/megatron/language_model.py
+++ b/nemo/collections/nlp/modules/common/megatron/language_model.py
@@ -535,6 +535,7 @@ def __init__(
         self.position_embedding_type = position_embedding_type
         self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
         self.sequence_parallel = config.sequence_parallel
+        self.context_parallel = parallel_state.get_context_parallel_world_size() > 1
 
         if kv_channels is None:
             assert (
@@ -722,6 +723,19 @@ def set_input_tensor(self, input_tensor):
 
         self.encoder.set_input_tensor(input_tensor[0])
 
+    def get_position_embedding_on_this_context_parallel_rank(self, position_embedding, seq_dim):
+        cp_size = parallel_state.get_context_parallel_world_size()
+        cp_rank = parallel_state.get_context_parallel_rank()
+        cp_idx = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=position_embedding.device)
+        position_embedding = position_embedding.view(
+            *position_embedding.shape[:seq_dim], 2 * cp_size, -1, *position_embedding.shape[(seq_dim + 1) :]
+        )
+        position_embedding = position_embedding.index_select(seq_dim, cp_idx)
+        position_embedding = position_embedding.view(
+            *position_embedding.shape[:seq_dim], -1, *position_embedding.shape[(seq_dim + 2) :]
+        )
+        return position_embedding
+
     def forward(
         self,
         enc_input_ids,
@@ -775,10 +789,16 @@ def forward(
         else:
             enc_seq_length = encoder_input.size(0)
 
+        if self.context_parallel:
+            enc_seq_length = enc_seq_length * parallel_state.get_context_parallel_world_size()
+
         rotary_pos_emb = None
         encoder_self_attention_relative_position_bias = None
         if self.position_embedding_type == 'rope':
             rotary_pos_emb = self.rotary_pos_emb(enc_seq_length)
+
+            if self.context_parallel:
+                rotary_pos_emb = self.get_position_embedding_on_this_context_parallel_rank(rotary_pos_emb, 0)
         elif (
             self.position_embedding_type == 'alibi'
             or self.position_embedding_type == 'sandwich'
@@ -790,6 +810,11 @@ def forward(
             # causal attention bias: [1, head, 1, k]
             # non-causal attention bias: [1, head, q, k]
 
+            if self.context_parallel and encoder_self_attention_relative_position_bias.shape[-2] > 1:
+                encoder_self_attention_relative_position_bias = self.get_position_embedding_on_this_context_parallel_rank(
+                    encoder_self_attention_relative_position_bias, 2
+                )
+
         # encoder.
         if enc_hidden_states is None:
             encoder_output = self.encoder(
diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
index 7431bffad26c..013838e7688e 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
@@ -63,6 +63,7 @@ def initialize_model_parallel_for_nemo(
     pipeline_model_parallel_size=1,
     virtual_pipeline_model_parallel_size=None,
     pipeline_model_parallel_split_rank=None,
+    context_parallel_size=1,
     micro_batch_size=None,
     global_batch_size=None,
     rampup_batch_size=None,
@@ -83,6 +84,7 @@ def initialize_model_parallel_for_nemo(
     app_state.tensor_model_parallel_size = tensor_model_parallel_size
     app_state.pipeline_model_parallel_size = pipeline_model_parallel_size
     app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
+    app_state.context_parallel_size = context_parallel_size
     app_state.use_fp8 = use_fp8
     app_state.init_mpi_proc_group = init_mpi_proc_group
     (
@@ -99,6 +101,7 @@ def initialize_model_parallel_for_nemo(
         pipeline_model_parallel_size_=pipeline_model_parallel_size,
         virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
         pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
+        context_parallel_size_=context_parallel_size,
     )
 
     # update apex.transformer globals
@@ -176,6 +179,7 @@ def fake_initialize_model_parallel(
     pipeline_model_parallel_size_,
     pipeline_model_parallel_split_rank_=None,
     virtual_pipeline_model_parallel_size_=None,
+    context_parallel_size_=1,
 ):
     """
     Fake initialize model data parallel groups so that we can instantiate model parallel models before DDP is initialized.
@@ -186,6 +190,7 @@ def fake_initialize_model_parallel(
     Arguments:
         tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
         pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
+        context_parallel_size: number of GPUs used to parallelize tokens of each input.
 
     Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
     use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
@@ -208,11 +213,14 @@ def fake_initialize_model_parallel(
     tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
     pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
     model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size
+    context_parallel_size = min(context_parallel_size_, world_size)
     assert (
-        world_size % tensor_model_parallel_size * pipeline_model_parallel_size == 0
-    ), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size}'
-    data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
+        world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size) == 0
+    ), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size} times context_parallel_size {context_parallel_size}'
+    data_parallel_size = world_size // (
+        tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
+    )
 
     num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
     num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
@@ -222,25 +230,58 @@ def fake_initialize_model_parallel(
         virtual_pipeline_model_parallel_rank = 0
 
     # Build the data-parallel groups.
-    all_data_parallel_group_ranks = []
+    all_data_parallel_group_ranks_with_cp = []
     for i in range(pipeline_model_parallel_size):
         start_rank = i * num_pipeline_model_parallel_groups
         end_rank = (i + 1) * num_pipeline_model_parallel_groups
-        for j in range(tensor_model_parallel_size):
-            ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
-            all_data_parallel_group_ranks.append(list(ranks))
+        for j in range(context_parallel_size * tensor_model_parallel_size):
+            ranks = range(start_rank + j, end_rank, context_parallel_size * tensor_model_parallel_size)
             if rank in ranks:
                 data_parallel_group = list(ranks)
-                logging.info(f'Rank {rank} has data parallel group: {data_parallel_group}')
+                logging.info(f'Rank {rank} has data parallel group : {data_parallel_group}')
+        for j in range(tensor_model_parallel_size):
+            ranks_with_cp = range(start_rank + j, end_rank, tensor_model_parallel_size)
+            all_data_parallel_group_ranks_with_cp.append(list(ranks_with_cp))
+            if rank in ranks_with_cp:
+                data_parallel_group_with_cp = list(ranks_with_cp)
+                logging.info(
+                    f'Rank {rank} has combined group of data parallel and context parallel : {data_parallel_group_with_cp}'
+                )
 
     data_parallel_rank = data_parallel_group.index(rank)
-    logging.info(f'All data parallel group ranks: {all_data_parallel_group_ranks}')
+    logging.info(
+        f'All data parallel group ranks with context parallel combined: {all_data_parallel_group_ranks_with_cp}'
+    )
     logging.info(f'Ranks {rank} has data parallel rank: {data_parallel_rank}')
 
+    # Build the context-parallel groups.
+    all_context_parallel_group_ranks = []
+    for i in range(pipeline_model_parallel_size):
+        for j in range(data_parallel_size):
+            start_rank = (
+                i * num_pipeline_model_parallel_groups + j * tensor_model_parallel_size * context_parallel_size
+            )
+            end_rank = (
+                i * num_pipeline_model_parallel_groups + (j + 1) * tensor_model_parallel_size * context_parallel_size
+            )
+            for k in range(tensor_model_parallel_size):
+                ranks = range(start_rank + k, end_rank, tensor_model_parallel_size)
+                all_context_parallel_group_ranks.append(list(ranks))
+                if rank in ranks:
+                    context_parallel_group = list(ranks)
+                    logging.info(f'Rank {rank} has context parallel group: {context_parallel_group}')
+
+    context_parallel_rank = context_parallel_group.index(rank)
+    logging.info(f'All context parallel group ranks: {all_context_parallel_group_ranks}')
+    logging.info(f'Ranks {rank} has context parallel rank: {context_parallel_rank}')
+
     # Build the model-parallel groups.
     all_model_parallel_group_ranks = []
-    for i in range(data_parallel_size):
-        ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]
+    for i in range(data_parallel_size * context_parallel_size):
+        ranks = [
+            data_parallel_group_ranks_with_cp[i]
+            for data_parallel_group_ranks_with_cp in all_data_parallel_group_ranks_with_cp
+        ]
         all_model_parallel_group_ranks.append(ranks)
         if rank in ranks:
             logging.info(f'Rank {rank} has model parallel group: {list(ranks)}')
diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py
index 83b0c6c0f2ac..ca8c0ecafefd 100644
--- a/nemo/collections/nlp/modules/common/megatron/transformer.py
+++ b/nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -1463,7 +1463,7 @@ def forward(
         # fp8_autocast will not do anything if TE or FP8 isn't used
         fp8_group = None
         if self.fp8 and parallel_state.model_parallel_is_initialized():
-            fp8_group = parallel_state.get_amax_reduction_group()
+            fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True)
 
         if HAVE_TE:
             # if TE is installed but fp8 is not available then this will do nothing
diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 293b8a3f5bce..5075863c3dbb 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -114,6 +114,7 @@ def init_model_parallel(sharp: bool, nccl_communicator_config_path: str = None)
             pipeline_model_parallel_size=app_state.pipeline_model_parallel_size,
             virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size,
             pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank,
+            context_parallel_size=app_state.context_parallel_size,
             nccl_communicator_config_path=nccl_communicator_config_path,
             use_sharp=sharp,
         )
@@ -223,7 +224,7 @@ def configure_ddp(self):
             # device_ids = self.determine_ddp_device_ids()
             self._model = DistributedDataParallel(
                 _LightningModuleWrapperBase(self.model),
-                process_group=parallel_state.get_data_parallel_group(),
+                process_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
                 **self._ddp_kwargs,
             )
 
diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py
index d7bc049c1808..a7baf67b9057 100644
--- a/nemo/core/optim/distributed_adam.py
+++ b/nemo/core/optim/distributed_adam.py
@@ -73,7 +73,7 @@ def __init__(
 
         # Initialize process groups
         if 'process_group' not in kwargs and not parallel_state.is_unitialized():
-            kwargs['process_group'] = parallel_state.get_data_parallel_group()
+            kwargs['process_group'] = parallel_state.get_data_parallel_group(with_context_parallel=True)
         if disable_distributed_parameters:
             world_size = torch.distributed.get_world_size()
             rank = torch.distributed.get_rank()
diff --git a/nemo/core/optim/optimizer_with_main_params.py b/nemo/core/optim/optimizer_with_main_params.py
index bddefc03e6d4..680b82ed7201 100644
--- a/nemo/core/optim/optimizer_with_main_params.py
+++ b/nemo/core/optim/optimizer_with_main_params.py
@@ -107,8 +107,8 @@ def zero(self):
 
     def allreduce_buffer(self):
         """Synchronous buffer data allreduce """
-        self.data.div_(get_data_parallel_world_size())
-        torch.distributed.all_reduce(self.data, group=get_data_parallel_group())
+        self.data.div_(get_data_parallel_world_size(with_context_parallel=True))
+        torch.distributed.all_reduce(self.data, group=get_data_parallel_group(with_context_parallel=True))
 
     def get(self, shape, start_index):
         """Return a tensor with the input `shape` as a view into the
@@ -205,8 +205,10 @@ def __init__(
 
         # used with tensor parallel only (no pipeline parallelism)
        # be careful, weight update cannot start until all async grad AR works are done
-        self._async_grad_allreduce = async_grad_allreduce and get_data_parallel_world_size() > 1
-        self._grad_divisor = 1 / get_data_parallel_world_size()
+        self._async_grad_allreduce = (
+            async_grad_allreduce and get_data_parallel_world_size(with_context_parallel=True) > 1
+        )
+        self._grad_divisor = 1 / get_data_parallel_world_size(with_context_parallel=True)
 
         if self._async_grad_allreduce:
             # use @no_sync to disable backward grad sync during gradient accumulation
@@ -341,27 +343,29 @@ def param_hook(*unused):
                         if self._grad_div_ar_fusion:
                             torch.distributed.all_reduce(
                                 allreduce_tensor,
-                                group=get_data_parallel_group(),
+                                group=get_data_parallel_group(with_context_parallel=True),
                                 async_op=True,
                                 op=torch.distributed._make_nccl_premul_sum(self._grad_divisor),
                             )
                         else:
-                            allreduce_tensor.div_(get_data_parallel_world_size())
+                            allreduce_tensor.div_(get_data_parallel_world_size(with_context_parallel=True))
                             torch.distributed.all_reduce(
-                                allreduce_tensor, group=get_data_parallel_group(), async_op=True,
+                                allreduce_tensor,
+                                group=get_data_parallel_group(with_context_parallel=True),
+                                async_op=True,
                             )
                     else:
                         if self._grad_div_ar_fusion:
                             torch.distributed.all_reduce(
                                 main_param.grad,
-                                group=get_data_parallel_group(),
+                                group=get_data_parallel_group(with_context_parallel=True),
                                 async_op=True,
                                 op=torch.distributed._make_nccl_premul_sum(self._grad_divisor),
                             )
                         else:
-                            main_param.grad.div_(get_data_parallel_world_size())
+                            main_param.grad.div_(get_data_parallel_world_size(with_context_parallel=True))
                             torch.distributed.all_reduce(
-                                main_param.grad, group=get_data_parallel_group(), async_op=True,
+                                main_param.grad, group=get_data_parallel_group(with_context_parallel=True), async_op=True,
                             )
 
         return param_hook
diff --git a/nemo/utils/app_state.py b/nemo/utils/app_state.py
index d06e1ac32e36..eb6b6d91ba5e 100644
--- a/nemo/utils/app_state.py
+++ b/nemo/utils/app_state.py
@@ -55,6 +55,7 @@ def __init__(self):
         self._data_parallel_group = None
         self._megatron_checkpoint_version = None
         self._use_fp8 = False
+        self._context_parallel_size = None
         self._init_mpi_proc_gruop = False
 
         self._random_seed = None
@@ -364,6 +365,22 @@ def use_fp8(self, use_fp8):
         """
         self._use_fp8 = use_fp8
 
+    @property
+    def context_parallel_size(self):
+        """ Property returns the number of GPUs in each context parallel group.
+            Returns:
+                Number of GPUs in each context parallel group.
+        """
+        return self._context_parallel_size
+
+    @context_parallel_size.setter
+    def context_parallel_size(self, size):
+        """ Property sets the number of GPUs in each context parallel group.
+            Args:
+                size (int): Number of GPUs in each context parallel group.
+        """
+        self._context_parallel_size = size
+
     @property
     def init_mpi_proc_group(self):
         """ Property sets the initialization of mpi process group.
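Note (illustration only, not part of the patch): the context-parallel sharding used by get_batch_on_this_context_parallel_rank and get_position_embedding_on_this_context_parallel_rank above splits the sequence dimension into 2 * cp_size equal chunks and keeps chunks cp_rank and (2 * cp_size - 1 - cp_rank) on each rank, so every rank holds one early and one late chunk and causal-attention work stays roughly balanced. A minimal standalone sketch (the helper name is made up for the example):

import torch

def shard_for_context_parallel(val, cp_size, cp_rank, seq_dim=1):
    # Split seq_dim into 2 * cp_size equal chunks and keep chunks cp_rank and
    # (2 * cp_size - 1 - cp_rank), mirroring the view/index_select pattern in the patch.
    val = val.view(
        *val.shape[0:seq_dim], 2 * cp_size, val.shape[seq_dim] // (2 * cp_size), *val.shape[(seq_dim + 1):]
    )
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1], device=val.device)
    val = val.index_select(seq_dim, index)
    return val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2):])

tokens = torch.arange(16).unsqueeze(0)  # toy batch of shape [1, 16]
print(shard_for_context_parallel(tokens, cp_size=2, cp_rank=0))  # keeps tokens 0-3 and 12-15
print(shard_for_context_parallel(tokens, cp_size=2, cp_rank=1))  # keeps tokens 4-7 and 8-11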
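Note (illustration only, not part of the patch): the group construction added to fake_initialize_model_parallel decomposes world_size into TP x CP x DP x PP. A small sketch that mirrors those loops for world_size=8, TP=2, CP=2, PP=1 (hence DP=2) and prints the resulting rank groups:

# Hypothetical values chosen for the example; the loop structure follows the patch.
tp, cp, pp, world_size = 2, 2, 1, 8
dp = world_size // (tp * pp * cp)
num_pp_groups = world_size // pp

dp_groups, dp_groups_with_cp, cp_groups = [], [], []
for i in range(pp):
    start, end = i * num_pp_groups, (i + 1) * num_pp_groups
    for j in range(cp * tp):
        dp_groups.append(list(range(start + j, end, cp * tp)))          # pure data-parallel groups
    for j in range(tp):
        dp_groups_with_cp.append(list(range(start + j, end, tp)))       # combined data + context parallel groups
for i in range(pp):
    for j in range(dp):
        start = i * num_pp_groups + j * tp * cp
        end = i * num_pp_groups + (j + 1) * tp * cp
        for k in range(tp):
            cp_groups.append(list(range(start + k, end, tp)))           # context-parallel groups

print(dp_groups)          # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(dp_groups_with_cp)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(cp_groups)          # [[0, 2], [1, 3], [4, 6], [5, 7]]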