From afce64ee72fab055986148c8b403744a0b74272c Mon Sep 17 00:00:00 2001 From: xren Date: Mon, 5 Jun 2023 18:40:59 -0700 Subject: [PATCH 01/47] make nemo recognize sequence_parallel_size Signed-off-by: xren --- .../language_modeling/megatron_base_model.py | 1 + .../modules/common/megatron/megatron_init.py | 27 +++++++++++++++++-- nemo/collections/nlp/parts/nlp_overrides.py | 1 + nemo/utils/app_state.py | 17 ++++++++++++ 4 files changed, 44 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 1237491fa39c..8cbf1702ce07 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -119,6 +119,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True): pipeline_model_parallel_size=cfg.get('pipeline_model_parallel_size', 1), virtual_pipeline_model_parallel_size=cfg.get('virtual_pipeline_model_parallel_size', None), pipeline_model_parallel_split_rank=cfg.get('pipeline_model_parallel_split_rank', 0), + sequence_parallel_size=cfg.get('sequence_parallel_size', 1), micro_batch_size=cfg.get('micro_batch_size'), global_batch_size=cfg.get('global_batch_size'), rampup_batch_size=cfg.get('rampup_batch_size'), diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index e0551fad5d16..9b992ba1124c 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -63,6 +63,7 @@ def initialize_model_parallel_for_nemo( pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None, pipeline_model_parallel_split_rank=None, + sequence_parallel_size=1, micro_batch_size=None, global_batch_size=None, rampup_batch_size=None, @@ -82,6 +83,7 @@ def initialize_model_parallel_for_nemo( app_state.tensor_model_parallel_size = tensor_model_parallel_size app_state.pipeline_model_parallel_size = pipeline_model_parallel_size app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size + app_state.sequence_parallel_size = sequence_parallel_size app_state.use_fp8 = use_fp8 ( app_state.tensor_model_parallel_rank, @@ -97,6 +99,7 @@ def initialize_model_parallel_for_nemo( pipeline_model_parallel_size_=pipeline_model_parallel_size, virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank, + sequence_parallel_size_=sequence_parallel_size, ) # update apex.transformer globals @@ -174,6 +177,7 @@ def fake_initialize_model_parallel( pipeline_model_parallel_size_, pipeline_model_parallel_split_rank_=None, virtual_pipeline_model_parallel_size_=None, + sequence_parallel_size_=1, ): """ Fake initialize model data parallel groups so that we can instantiate model parallel models before DDP is initialized. @@ -184,6 +188,7 @@ def fake_initialize_model_parallel( Arguments: tensor_model_parallel_size: number of GPUs used to parallelize model tensor. pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline. + sequence_parallel_size: number of GPUs used to parallelize tokens of each input sequence. Let's say we have a total of 16 GPUs denoted by g0 ... 
g15 and we use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize @@ -206,10 +211,11 @@ def fake_initialize_model_parallel( tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size) pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size) model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size + sequence_parallel_size = min(sequence_parallel_size_, world_size) assert ( - world_size % tensor_model_parallel_size * pipeline_model_parallel_size == 0 - ), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size}' + world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * sequence_parallel_size) == 0 + ), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size} times sequence_parallel_size {sequence_parallel_size}' data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size) num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size @@ -235,6 +241,23 @@ def fake_initialize_model_parallel( logging.info(f'All data parallel group ranks: {all_data_parallel_group_ranks}') logging.info(f'Ranks {rank} has data parallel rank: {data_parallel_rank}') + # Build the sequence-parallel groups. + all_sequence_parallel_group_ranks = [] + for i in range(pipeline_model_parallel_size): + for j in range(data_parallel_size // sequence_parallel_size): + start_rank = i * num_pipeline_model_parallel_groups + j * tensor_model_parallel_size * sequence_parallel_size + end_rank = i * num_pipeline_model_parallel_groups + (j + 1) * tensor_model_parallel_size * sequence_parallel_size + for k in range(tensor_model_parallel_size): + ranks = range(start_rank + k, end_rank, tensor_model_parallel_size) + all_sequence_parallel_group_ranks.append(list(ranks)) + if rank in ranks: + sequence_parallel_group = list(ranks) + logging.info(f'Rank {rank} has sequence parallel group: {sequence_parallel_group}') + + sequence_parallel_rank = sequence_parallel_group.index(rank) + logging.info(f'All sequence parallel group ranks: {all_sequence_parallel_group_ranks}') + logging.info(f'Ranks {rank} has sequence parallel rank: {squence_parallel_rank}') + # Build the model-parallel groups. 
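# Illustrative sketch (not part of the patch): the sequence-parallel group loop
# above can be checked standalone. The helper below is hypothetical and only
# mirrors the indexing in fake_initialize_model_parallel, assuming a toy layout
# of world_size=8, tensor_model_parallel_size=2, pipeline_model_parallel_size=1,
# sequence_parallel_size=2.
def _enumerate_sequence_parallel_groups(world_size=8, tp=2, pp=1, sp=2):
    dp = world_size // (tp * pp)          # data-parallel size as computed in this patch
    num_pp_groups = world_size // pp      # ranks per pipeline block
    groups = []
    for i in range(pp):
        for j in range(dp // sp):
            start = i * num_pp_groups + j * tp * sp
            end = i * num_pp_groups + (j + 1) * tp * sp
            for k in range(tp):
                # ranks that share a tensor-parallel rank but hold different sequence chunks
                groups.append(list(range(start + k, end, tp)))
    return groups
# _enumerate_sequence_parallel_groups() -> [[0, 2], [1, 3], [4, 6], [5, 7]]:
# one group per (tensor, data, pipeline) slot, each holding the sp ranks that split one sequence.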
all_model_parallel_group_ranks = [] for i in range(data_parallel_size): diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 5a0f028ddbe9..2ff01c5fcc33 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -168,6 +168,7 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None: pipeline_model_parallel_size=app_state.pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank, + sequence_parallel_size=app_state.sequence_parallel_size, ) # assert that fake tp and pp rank match after model parallel init diff --git a/nemo/utils/app_state.py b/nemo/utils/app_state.py index c3ead0bff48f..3af7f9a92ac8 100644 --- a/nemo/utils/app_state.py +++ b/nemo/utils/app_state.py @@ -55,6 +55,7 @@ def __init__(self): self._data_parallel_group = None self._megatron_checkpoint_version = None self._use_fp8 = False + self._sequence_parallel_size = None self._random_seed = None @@ -363,6 +364,22 @@ def use_fp8(self, use_fp8): """ self._use_fp8 = use_fp8 + @property + def sequence_parallel_size(self): + """ Property returns the number of GPUs in each sequence parallel group. + Returns: + Number of GPUs in each sequence parallel group. + """ + return self._sequence_parallel_size + + @sequence_parallel_size.setter + def sequence_parallel_size(self, size): + """ Property sets the number of GPUs in each sequence parallel group. + Args: + size (int): Number of GPUs in each sequence parallel group. + """ + self._sequence_parallel_size = size + @property def random_seed(self): """ Property returns the random seed. From e31300099c72bbcd8acfabf23e8ee96b28dcdae7 Mon Sep 17 00:00:00 2001 From: xren Date: Mon, 5 Jun 2023 19:19:31 -0700 Subject: [PATCH 02/47] add helper functions to set up SP running in TE Signed-off-by: xren --- .../language_modeling/megatron_gpt_model.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 9aadb6853190..e43377f16544 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -1033,6 +1033,7 @@ def setup(self, stage=None): if self.cfg.get('transformer_engine', False): self.setup_transformer_engine_tp_groups() + self.setup_transformer_engine_sp_running() def setup_training_data(self, cfg): if hasattr(self, '_train_ds'): @@ -1089,6 +1090,7 @@ def dummy(): if self.cfg.get('transformer_engine', False): self.setup_transformer_engine_tp_groups() + self.setup_transformer_engine_sp_running() # set the default sampling params if it is None. 
# default do greedy sampling @@ -1186,6 +1188,31 @@ def setup_transformer_engine_tp_groups(self): else: self._set_tp_groups(self.model) + def _set_sp_running(self, module): + """ Helper method to set sp running for transformer engine""" + + if self.cfg.get('transformer_engine', False): + logging.info(f'Setting up transformer engine modules for seequence parallelism.') + sp_stream = torch.cuda.Stream() + if self.cfg.get('megatron_amp_O2', 'False'): + # when using O2 additional module key is added that casts the weights + for layer in module.module.language_model.encoder.layers: + layer.set_sequence_parallel_running(parallel_state.get_sequence_parallel_group(), sp_stream) + + else: + for layer in module.language_model.encoder.layers: + layer.set_sequence_parallel_running(parallel_state.get_sequence_parallel_group(), sp_stream) + + def setup_transformer_engine_sp_running(self): + """ This should be called after model parallel groups have been initialized + and only needs to be called when using Transformer Engine. + """ + if isinstance(self.model, list): + for module in self.model: + self._set_sp_running(module) + else: + self._set_sp_running(self.model) + def on_save_checkpoint(self, checkpoint) -> None: """LightningModule hook: https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-save-checkpoint From 55809554251c243a8e2cb2c930ff557beb35fe1f Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Wed, 7 Jun 2023 17:05:20 -0700 Subject: [PATCH 03/47] slice seq length for a specific rank Signed-off-by: Xiaowei Ren --- .../models/language_modeling/megatron_gpt_model.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 094aaaa2021e..9eca6dbfbcbd 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -734,6 +734,18 @@ def __next__(self): # TODO @tmoon: Use once available in Megatron-LM # return DataIteratorList(iters) + def get_batch_on_this_sequence_parallel_rank(batch): + sequence_parallel_size = parallel_state.get_sequence_parallel_world_size() + if sequence_parallel_size > 1: + sequence_parallel_rank = parallel_state.get_sequence_parallel_rank() + for key, val in batch.items(): + if val is not None: + val = val.view(val.shape[0], 2*sequence_parallel_size, val.shape[1]//(2*sequence_parallel_size), *val.shape[2:]) + val = torch.cat([val[:, sequence_parallel_rank, ...], val[:, (2*sequence_parallel_size - sequence_parallel_rank - 1), ...]], dim=1) + batch[key] = val + + return batch + def get_forward_output_and_loss_func(self, validation_step=False): def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): @@ -754,6 +766,8 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ required_keys.remove('attention_mask') batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()} + batch = self.get_batch_on_this_sequence_parallel_rank(batch) + # Model forward pass output_tensor = model( batch['tokens'], From ebd63238e1ecd9ff0514b5fbba13e4b541d1807d Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Wed, 7 Jun 2023 23:17:29 -0700 Subject: [PATCH 04/47] fix data_parallel_size calculation Signed-off-by: Xiaowei Ren --- .../nlp/modules/common/megatron/megatron_init.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 
deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index 9b992ba1124c..6c22e8953aa5 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -216,7 +216,7 @@ def fake_initialize_model_parallel( assert ( world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * sequence_parallel_size) == 0 ), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size} times sequence_parallel_size {sequence_parallel_size}' - data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size) + data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size * sequence_parallel_size) num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size @@ -244,7 +244,7 @@ def fake_initialize_model_parallel( # Build the sequence-parallel groups. all_sequence_parallel_group_ranks = [] for i in range(pipeline_model_parallel_size): - for j in range(data_parallel_size // sequence_parallel_size): + for j in range(data_parallel_size): start_rank = i * num_pipeline_model_parallel_groups + j * tensor_model_parallel_size * sequence_parallel_size end_rank = i * num_pipeline_model_parallel_groups + (j + 1) * tensor_model_parallel_size * sequence_parallel_size for k in range(tensor_model_parallel_size): @@ -256,11 +256,11 @@ def fake_initialize_model_parallel( sequence_parallel_rank = sequence_parallel_group.index(rank) logging.info(f'All sequence parallel group ranks: {all_sequence_parallel_group_ranks}') - logging.info(f'Ranks {rank} has sequence parallel rank: {squence_parallel_rank}') + logging.info(f'Ranks {rank} has sequence parallel rank: {sequence_parallel_rank}') # Build the model-parallel groups. 
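# Illustrative sketch (not part of the patch): a standalone check that, with
# data_parallel_size now divided by sequence_parallel_size, the group loop above
# still places every rank in exactly one sequence-parallel group. Sizes are
# assumed toy values and the helper is hypothetical.
def _check_sequence_parallel_cover(world=16, tp=2, pp=2, sp=2):
    dp = world // (tp * pp * sp)          # 2, using the corrected formula
    num_pp_groups = world // pp           # 8
    seen = []
    for i in range(pp):
        for j in range(dp):
            start = i * num_pp_groups + j * tp * sp
            for k in range(tp):
                seen += list(range(start + k, start + tp * sp, tp))
    assert sorted(seen) == list(range(world))  # every rank appears exactly once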
all_model_parallel_group_ranks = [] - for i in range(data_parallel_size): + for i in range(data_parallel_size * sequence_parallel_size): ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks] all_model_parallel_group_ranks.append(ranks) if rank in ranks: From 58cca3d37939603b3ca16446602de311cc6a9a46 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Thu, 8 Jun 2023 08:54:13 -0700 Subject: [PATCH 05/47] minor change Signed-off-by: Xiaowei Ren --- nemo/collections/nlp/modules/common/megatron/megatron_init.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index 6c22e8953aa5..0da95c343ff4 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -237,7 +237,7 @@ def fake_initialize_model_parallel( data_parallel_group = list(ranks) logging.info(f'Rank {rank} has data parallel group: {data_parallel_group}') - data_parallel_rank = data_parallel_group.index(rank) + data_parallel_rank = data_parallel_group.index(rank) // sequence_parallel_size logging.info(f'All data parallel group ranks: {all_data_parallel_group_ranks}') logging.info(f'Ranks {rank} has data parallel rank: {data_parallel_rank}') From 87f027a2826d6a42707140b8e1e55c030958c056 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Thu, 8 Jun 2023 14:35:17 -0700 Subject: [PATCH 06/47] add missing argument of self Signed-off-by: Xiaowei Ren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 9eca6dbfbcbd..064f4654b697 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -734,7 +734,7 @@ def __next__(self): # TODO @tmoon: Use once available in Megatron-LM # return DataIteratorList(iters) - def get_batch_on_this_sequence_parallel_rank(batch): + def get_batch_on_this_sequence_parallel_rank(self, batch): sequence_parallel_size = parallel_state.get_sequence_parallel_world_size() if sequence_parallel_size > 1: sequence_parallel_rank = parallel_state.get_sequence_parallel_rank() From 9ebfcf7fae79c97ed2a24cf9c4563ab8619d6894 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Thu, 8 Jun 2023 16:09:18 -0700 Subject: [PATCH 07/47] pass sp_global_ranks to TE transformer layer Signed-off-by: Xiaowei Ren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 064f4654b697..9bdb43555a2f 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -1209,11 +1209,15 @@ def _set_sp_running(self, module): if self.cfg.get('megatron_amp_O2', 'False'): # when using O2 additional module key is added that casts the weights for layer in module.module.language_model.encoder.layers: - layer.set_sequence_parallel_running(parallel_state.get_sequence_parallel_group(), sp_stream) + layer.set_sequence_parallel_running(parallel_state.get_sequence_parallel_group(), + 
parallel_state.get_sequence_parallel_global_ranks(), + sp_stream) else: for layer in module.language_model.encoder.layers: - layer.set_sequence_parallel_running(parallel_state.get_sequence_parallel_group(), sp_stream) + layer.set_sequence_parallel_running(parallel_state.get_sequence_parallel_group(), + parallel_state.get_sequence_parallel_global_ranks(), + sp_stream) def setup_transformer_engine_sp_running(self): """ This should be called after model parallel groups have been initialized From 728fd43966560f5927b4bf07064ab9dbd5eee889 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Thu, 8 Jun 2023 18:25:31 -0700 Subject: [PATCH 08/47] fix nsys setting Signed-off-by: Xiaowei Ren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 9bdb43555a2f..15c9067a5a0a 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -270,7 +270,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): # Convert the global-batch-based profile index to micro-batch index if hasattr(self, '_nsys_profile_enabled'): mp_size = cfg.get('tensor_model_parallel_size', 1) * cfg.get('pipeline_model_parallel_size', 1) - data_parallel_world_size = trainer.world_size // mp_size + sp_size = cfg.get('sequence_parallel_size', 1) + data_parallel_world_size = trainer.world_size // (mp_size * sp_size) grad_accum_steps = cfg.get('global_batch_size') // (cfg.get('micro_batch_size') * data_parallel_world_size) self._nsys_profile_start_step *= grad_accum_steps self._nsys_profile_end_step *= grad_accum_steps From 66615e822baf45828e5e55bab937cd0a06ca827f Mon Sep 17 00:00:00 2001 From: xren Date: Tue, 13 Jun 2023 14:19:47 -0700 Subject: [PATCH 09/47] fix seq_len calculation Signed-off-by: xren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 15c9067a5a0a..64cf56020629 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -439,7 +439,7 @@ def forward(self, tokens, text_position_ids, attention_mask, labels): return output_tensor def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): - tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size] + tensor_shape = [self.cfg.encoder_seq_length // self.cfg.get('sequence_parallel_size', 1), self.cfg.micro_batch_size, self.cfg.hidden_size] # handle asynchronous grad reduction no_sync_func = None @@ -736,7 +736,7 @@ def __next__(self): # return DataIteratorList(iters) def get_batch_on_this_sequence_parallel_rank(self, batch): - sequence_parallel_size = parallel_state.get_sequence_parallel_world_size() + sequence_parallel_size = self.cfg.get('sequence_parallel_size', 1) if sequence_parallel_size > 1: sequence_parallel_rank = parallel_state.get_sequence_parallel_rank() for key, val in batch.items(): From e1f5eb79f9ba9986b1dc6f2c1242d74082aec8d0 Mon Sep 17 00:00:00 2001 From: xren Date: Fri, 16 Jun 2023 19:25:42 -0700 Subject: [PATCH 10/47] fix attn_mask split across seq-length dim Signed-off-by: xren --- 
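# Illustrative sketch (not from the patch): the shape arithmetic behind the
# tensor_shape change above, with assumed example numbers. With sequence
# parallelism each rank keeps seq_len // sp tokens, and the sequence must be
# divisible by 2 * sp because the batch is later split into 2*sp chunks.
# The data-parallel world size likewise shrinks to world_size // (tp * pp * sp).
def _per_rank_tensor_shape(seq_len=4096, micro_batch=2, hidden=4096, sp_size=4):
    assert seq_len % (2 * sp_size) == 0
    return [seq_len // sp_size, micro_batch, hidden]
# _per_rank_tensor_shape() -> [1024, 2, 4096] instead of [4096, 2, 4096]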
.../nlp/models/language_modeling/megatron_gpt_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 64cf56020629..dbf1354a30cd 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -741,8 +741,12 @@ def get_batch_on_this_sequence_parallel_rank(self, batch): sequence_parallel_rank = parallel_state.get_sequence_parallel_rank() for key, val in batch.items(): if val is not None: - val = val.view(val.shape[0], 2*sequence_parallel_size, val.shape[1]//(2*sequence_parallel_size), *val.shape[2:]) - val = torch.cat([val[:, sequence_parallel_rank, ...], val[:, (2*sequence_parallel_size - sequence_parallel_rank - 1), ...]], dim=1) + if key == 'attention_mask': + val = val.view(val.shape[0:2], 2*sequence_parallel_size, val.shape[2]//(2*sequence_parallel_size), *val.shape[3:]) + val = torch.cat([val[:, :, sequence_parallel_rank, ...], val[:, :, (2*sequence_parallel_size - sequence_parallel_rank - 1), ...]], dim=2) + else: + val = val.view(val.shape[0], 2*sequence_parallel_size, val.shape[1]//(2*sequence_parallel_size), *val.shape[2:]) + val = torch.cat([val[:, sequence_parallel_rank, ...], val[:, (2*sequence_parallel_size - sequence_parallel_rank - 1), ...]], dim=1) batch[key] = val return batch From cf0c75cc06c10f8987de469b79a94dbf8722f5ce Mon Sep 17 00:00:00 2001 From: xren Date: Fri, 16 Jun 2023 20:01:00 -0700 Subject: [PATCH 11/47] code update of input split Signed-off-by: xren --- .../language_modeling/megatron_gpt_model.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index dbf1354a30cd..16bc1e56e626 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -736,17 +736,16 @@ def __next__(self): # return DataIteratorList(iters) def get_batch_on_this_sequence_parallel_rank(self, batch): - sequence_parallel_size = self.cfg.get('sequence_parallel_size', 1) - if sequence_parallel_size > 1: - sequence_parallel_rank = parallel_state.get_sequence_parallel_rank() + sp_size = self.cfg.get('sequence_parallel_size', 1) + if sp_size > 1: + sp_rank = parallel_state.get_sequence_parallel_rank() for key, val in batch.items(): if val is not None: - if key == 'attention_mask': - val = val.view(val.shape[0:2], 2*sequence_parallel_size, val.shape[2]//(2*sequence_parallel_size), *val.shape[3:]) - val = torch.cat([val[:, :, sequence_parallel_rank, ...], val[:, :, (2*sequence_parallel_size - sequence_parallel_rank - 1), ...]], dim=2) - else: - val = val.view(val.shape[0], 2*sequence_parallel_size, val.shape[1]//(2*sequence_parallel_size), *val.shape[2:]) - val = torch.cat([val[:, sequence_parallel_rank, ...], val[:, (2*sequence_parallel_size - sequence_parallel_rank - 1), ...]], dim=1) + seq_dim = 1 if key != 'attnetion_mask' else 2 + val = val.view(*val.shape[0:seq_dim], 2*sp_size, val.shape[seq_dim]//(2*sp_size), *val.shape[(seq_dim+1):]) + index = torch.tensor([sp_rank, (2*sp_size-sp_rank-1)], device=val.device) + val = val.index_select(seq_dim, index) + val = val.view(*val.shape[0:seq_dim], len(index)*val.shape[seq_dim+1], *val.shape[(seq_dim+2):]) batch[key] = val return 
batch From b57e2185b53ba7a3dc9410471a9bee02652b9193 Mon Sep 17 00:00:00 2001 From: xren Date: Tue, 20 Jun 2023 23:57:49 -0700 Subject: [PATCH 12/47] fix loss calculation Signed-off-by: xren --- .../models/language_modeling/megatron_gpt_model.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 16bc1e56e626..45e44d5a2398 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -739,14 +739,19 @@ def get_batch_on_this_sequence_parallel_rank(self, batch): sp_size = self.cfg.get('sequence_parallel_size', 1) if sp_size > 1: sp_rank = parallel_state.get_sequence_parallel_rank() + loss_mask_sum = None for key, val in batch.items(): if val is not None: + if key == 'loss_mask': + loss_mask_sum = val.sum() seq_dim = 1 if key != 'attnetion_mask' else 2 val = val.view(*val.shape[0:seq_dim], 2*sp_size, val.shape[seq_dim]//(2*sp_size), *val.shape[(seq_dim+1):]) index = torch.tensor([sp_rank, (2*sp_size-sp_rank-1)], device=val.device) val = val.index_select(seq_dim, index) val = val.view(*val.shape[0:seq_dim], len(index)*val.shape[seq_dim+1], *val.shape[(seq_dim+2):]) batch[key] = val + if loss_mask_sum is not None: + batch['loss_mask_sum'] = loss_mask_sum return batch @@ -783,7 +788,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ def loss_func(output_tensor): # Loss for a micro-batch (ub) - loss_for_ub = self.loss_func(batch['loss_mask'], output_tensor) + loss_for_ub = self.loss_func(batch['loss_mask'], batch['loss_mask_sum'], output_tensor) if validation_step and not self.cfg.data.get('validation_drop_last', True): num_valid_tokens_in_ub = batch['loss_mask'].sum() if loss_for_ub.isnan(): @@ -888,11 +893,12 @@ def test_epoch_end(self, outputs): averaged_loss = average_losses_across_data_parallel_group(outputs) logging.info(f'test_loss: {averaged_loss[0]}') - def loss_func(self, loss_mask, output_tensor): + def loss_func(self, loss_mask, loss_mask_sum, output_tensor): losses = output_tensor.float() loss_mask = loss_mask.view(-1).float() # TODO: add nemo version here - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() # sequence level nll + sp_size = self.cfg.get('sequence_parallel_size', 1) + loss = torch.sum(losses.view(-1) * loss_mask) / (loss_mask_sum / sp_size) # sequence level nll return loss def build_train_valid_test_datasets(self): From 69f4ae889efb98bd09bd810fa771619d9b81b648 Mon Sep 17 00:00:00 2001 From: xren Date: Wed, 21 Jun 2023 00:57:55 -0700 Subject: [PATCH 13/47] fix loss_mask_sum calculation Signed-off-by: xren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 45e44d5a2398..2f57d6184979 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -737,21 +737,18 @@ def __next__(self): def get_batch_on_this_sequence_parallel_rank(self, batch): sp_size = self.cfg.get('sequence_parallel_size', 1) + loss_mask_sum = None if 'loss_mask' not in batch else batch['loss_mask'].sum() if sp_size > 1: sp_rank = parallel_state.get_sequence_parallel_rank() - 
loss_mask_sum = None for key, val in batch.items(): if val is not None: - if key == 'loss_mask': - loss_mask_sum = val.sum() seq_dim = 1 if key != 'attnetion_mask' else 2 val = val.view(*val.shape[0:seq_dim], 2*sp_size, val.shape[seq_dim]//(2*sp_size), *val.shape[(seq_dim+1):]) index = torch.tensor([sp_rank, (2*sp_size-sp_rank-1)], device=val.device) val = val.index_select(seq_dim, index) val = val.view(*val.shape[0:seq_dim], len(index)*val.shape[seq_dim+1], *val.shape[(seq_dim+2):]) batch[key] = val - if loss_mask_sum is not None: - batch['loss_mask_sum'] = loss_mask_sum + batch['loss_mask_sum'] = loss_mask_sum return batch From a38dd9aef2ed8ebfbaf94cd4c84037f9d44dc2af Mon Sep 17 00:00:00 2001 From: xren Date: Thu, 22 Jun 2023 14:35:09 -0700 Subject: [PATCH 14/47] fix losss calculation Signed-off-by: xren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 2f57d6184979..5976418d2efa 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -786,6 +786,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ def loss_func(output_tensor): # Loss for a micro-batch (ub) loss_for_ub = self.loss_func(batch['loss_mask'], batch['loss_mask_sum'], output_tensor) + sp_size = self.cfg.get('sequence_parallel_size', 1) if validation_step and not self.cfg.data.get('validation_drop_last', True): num_valid_tokens_in_ub = batch['loss_mask'].sum() if loss_for_ub.isnan(): @@ -807,7 +808,7 @@ def loss_func(output_tensor): return loss_for_ub, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu} else: reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) - return loss_for_ub, {'avg': reduced_loss} + return loss_for_ub*sp_size, {'avg': reduced_loss} return output_tensor, loss_func @@ -894,8 +895,8 @@ def loss_func(self, loss_mask, loss_mask_sum, output_tensor): losses = output_tensor.float() loss_mask = loss_mask.view(-1).float() # TODO: add nemo version here - sp_size = self.cfg.get('sequence_parallel_size', 1) - loss = torch.sum(losses.view(-1) * loss_mask) / (loss_mask_sum / sp_size) # sequence level nll + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum # sequence level nll + torch.distributed.all_reduce(loss, group=parallel_state.get_sequence_parallel_group()) return loss def build_train_valid_test_datasets(self): From 8ac42f1d56b911ab0f638d3694d17a0705bedab7 Mon Sep 17 00:00:00 2001 From: xren Date: Thu, 22 Jun 2023 15:11:20 -0700 Subject: [PATCH 15/47] rename sequence_parallelism to context_parallelism Signed-off-by: xren --- .../language_modeling/megatron_base_model.py | 2 +- .../language_modeling/megatron_gpt_model.py | 58 +++++++++---------- .../modules/common/megatron/megatron_init.py | 42 +++++++------- nemo/collections/nlp/parts/nlp_overrides.py | 2 +- nemo/utils/app_state.py | 20 +++---- 5 files changed, 62 insertions(+), 62 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 15a0b5ba0451..f5e35275f441 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -127,7 +127,7 @@ def __init__(self, cfg: 
DictConfig, trainer: Trainer, no_lm_init=True): pipeline_model_parallel_size=cfg.get('pipeline_model_parallel_size', 1), virtual_pipeline_model_parallel_size=cfg.get('virtual_pipeline_model_parallel_size', None), pipeline_model_parallel_split_rank=cfg.get('pipeline_model_parallel_split_rank', 0), - sequence_parallel_size=cfg.get('sequence_parallel_size', 1), + context_parallel_size=cfg.get('context_parallel_size', 1), micro_batch_size=cfg.get('micro_batch_size'), global_batch_size=cfg.get('global_batch_size'), rampup_batch_size=cfg.get('rampup_batch_size'), diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index c4cbf23ecb18..5ea0df702122 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -276,8 +276,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): # Convert the global-batch-based profile index to micro-batch index if hasattr(self, '_nsys_profile_enabled'): mp_size = cfg.get('tensor_model_parallel_size', 1) * cfg.get('pipeline_model_parallel_size', 1) - sp_size = cfg.get('sequence_parallel_size', 1) - data_parallel_world_size = trainer.world_size // (mp_size * sp_size) + cp_size = cfg.get('context_parallel_size', 1) + data_parallel_world_size = trainer.world_size // (mp_size * cp_size) grad_accum_steps = cfg.get('global_batch_size') // (cfg.get('micro_batch_size') * data_parallel_world_size) self._nsys_profile_start_step *= grad_accum_steps self._nsys_profile_end_step *= grad_accum_steps @@ -447,7 +447,7 @@ def forward(self, tokens, text_position_ids, attention_mask, labels): return output_tensor def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): - tensor_shape = [self.cfg.encoder_seq_length // self.cfg.get('sequence_parallel_size', 1), self.cfg.micro_batch_size, self.cfg.hidden_size] + tensor_shape = [self.cfg.encoder_seq_length // self.cfg.get('context_parallel_size', 1), self.cfg.micro_batch_size, self.cfg.hidden_size] # handle asynchronous grad reduction no_sync_func = None @@ -748,16 +748,16 @@ def __next__(self): # TODO @tmoon: Use once available in Megatron-LM # return DataIteratorList(iters) - def get_batch_on_this_sequence_parallel_rank(self, batch): - sp_size = self.cfg.get('sequence_parallel_size', 1) + def get_batch_on_this_context_parallel_rank(self, batch): + cp_size = self.cfg.get('context_parallel_size', 1) loss_mask_sum = None if 'loss_mask' not in batch else batch['loss_mask'].sum() - if sp_size > 1: - sp_rank = parallel_state.get_sequence_parallel_rank() + if cp_size > 1: + cp_rank = parallel_state.get_context_parallel_rank() for key, val in batch.items(): if val is not None: seq_dim = 1 if key != 'attnetion_mask' else 2 - val = val.view(*val.shape[0:seq_dim], 2*sp_size, val.shape[seq_dim]//(2*sp_size), *val.shape[(seq_dim+1):]) - index = torch.tensor([sp_rank, (2*sp_size-sp_rank-1)], device=val.device) + val = val.view(*val.shape[0:seq_dim], 2*cp_size, val.shape[seq_dim]//(2*cp_size), *val.shape[(seq_dim+1):]) + index = torch.tensor([cp_rank, (2*cp_size-cp_rank-1)], device=val.device) val = val.index_select(seq_dim, index) val = val.view(*val.shape[0:seq_dim], len(index)*val.shape[seq_dim+1], *val.shape[(seq_dim+2):]) batch[key] = val @@ -785,7 +785,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ required_keys.remove('attention_mask') batch = {key: val.cuda(non_blocking=True) if key in required_keys 
else None for key, val in batch.items()} - batch = self.get_batch_on_this_sequence_parallel_rank(batch) + batch = self.get_batch_on_this_context_parallel_rank(batch) # Model forward pass output_tensor = model( @@ -799,7 +799,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ def loss_func(output_tensor): # Loss for a micro-batch (ub) loss_for_ub = self.loss_func(batch['loss_mask'], batch['loss_mask_sum'], output_tensor) - sp_size = self.cfg.get('sequence_parallel_size', 1) + cp_size = self.cfg.get('context_parallel_size', 1) if validation_step and not self.cfg.data.get('validation_drop_last', True): num_valid_tokens_in_ub = batch['loss_mask'].sum() if loss_for_ub.isnan(): @@ -821,7 +821,7 @@ def loss_func(output_tensor): return loss_for_ub, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu} else: reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) - return loss_for_ub*sp_size, {'avg': reduced_loss} + return loss_for_ub*cp_size, {'avg': reduced_loss} return output_tensor, loss_func @@ -910,7 +910,7 @@ def loss_func(self, loss_mask, loss_mask_sum, output_tensor): loss_mask = loss_mask.view(-1).float() # TODO: add nemo version here loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum # sequence level nll - torch.distributed.all_reduce(loss, group=parallel_state.get_sequence_parallel_group()) + torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group()) return loss def build_train_valid_test_datasets(self): @@ -1053,7 +1053,7 @@ def setup(self, stage=None): if self.cfg.get('transformer_engine', False): self.setup_transformer_engine_tp_groups() - self.setup_transformer_engine_sp_running() + self.setup_transformer_engine_cp_running() def setup_training_data(self, cfg): if hasattr(self, '_train_ds'): @@ -1110,7 +1110,7 @@ def dummy(): if self.cfg.get('transformer_engine', False): self.setup_transformer_engine_tp_groups() - self.setup_transformer_engine_sp_running() + self.setup_transformer_engine_cp_running() # set the default sampling params if it is None. 
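# Illustrative sketch (not from the patch): what the context-parallel batch
# split and loss normalization above amount to, using plain Python lists in
# place of tensors. The helper names and numbers are hypothetical.
def _chunks_for_rank(seq, cp_size, cp_rank):
    # Split the sequence into 2*cp_size equal chunks and give rank r the pair
    # (r, 2*cp_size - 1 - r); pairing a head chunk with a tail chunk balances
    # causal-attention work across ranks.
    n = len(seq) // (2 * cp_size)
    chunks = [seq[i * n:(i + 1) * n] for i in range(2 * cp_size)]
    return chunks[cp_rank] + chunks[2 * cp_size - 1 - cp_rank]

tokens = list(range(8))
assert _chunks_for_rank(tokens, cp_size=2, cp_rank=0) == [0, 1, 6, 7]
assert _chunks_for_rank(tokens, cp_size=2, cp_rank=1) == [2, 3, 4, 5]

# Loss: each rank sums only its local masked token losses but divides by the
# global loss_mask_sum taken before slicing, so an all-reduce(SUM) over the
# context-parallel group recovers the full-sequence average:
per_rank_losses = {0: [0.5, 0.7, 0.2, 0.4], 1: [0.3, 0.1, 0.6, 0.8]}  # assumed values, all tokens unmasked
global_mask_sum = 8.0
partials = [sum(v) / global_mask_sum for v in per_rank_losses.values()]
full = sum(sum(v) for v in per_rank_losses.values()) / global_mask_sum
assert abs(sum(partials) - full) < 1e-9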
# default do greedy sampling @@ -1206,34 +1206,34 @@ def setup_transformer_engine_tp_groups(self): else: self._set_tp_groups(self.model) - def _set_sp_running(self, module): - """ Helper method to set sp running for transformer engine""" + def _set_cp_running(self, module): + """ Helper method to set cp running for transformer engine""" if self.cfg.get('transformer_engine', False): - logging.info(f'Setting up transformer engine modules for seequence parallelism.') - sp_stream = torch.cuda.Stream() + logging.info(f'Setting up transformer engine modules for context parallelism.') + cp_stream = torch.cuda.Stream() if self.cfg.get('megatron_amp_O2', 'False'): # when using O2 additional module key is added that casts the weights for layer in module.module.language_model.encoder.layers: - layer.set_sequence_parallel_running(parallel_state.get_sequence_parallel_group(), - parallel_state.get_sequence_parallel_global_ranks(), - sp_stream) + layer.set_context_parallel_running(parallel_state.get_context_parallel_group(), + parallel_state.get_context_parallel_global_ranks(), + cp_stream) else: for layer in module.language_model.encoder.layers: - layer.set_sequence_parallel_running(parallel_state.get_sequence_parallel_group(), - parallel_state.get_sequence_parallel_global_ranks(), - sp_stream) + layer.set_context_parallel_running(parallel_state.get_context_parallel_group(), + parallel_state.get_context_parallel_global_ranks(), + cp_stream) - def setup_transformer_engine_sp_running(self): - """ This should be called after model parallel groups have been initialized + def setup_transformer_engine_cp_running(self): + """ This should be called after context parallel groups have been initialized and only needs to be called when using Transformer Engine. """ if isinstance(self.model, list): for module in self.model: - self._set_sp_running(module) + self._set_cp_running(module) else: - self._set_sp_running(self.model) + self._set_cp_running(self.model) def on_save_checkpoint(self, checkpoint) -> None: """LightningModule hook: diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index 0da95c343ff4..b153f2e2aff1 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -63,7 +63,7 @@ def initialize_model_parallel_for_nemo( pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None, pipeline_model_parallel_split_rank=None, - sequence_parallel_size=1, + context_parallel_size=1, micro_batch_size=None, global_batch_size=None, rampup_batch_size=None, @@ -83,7 +83,7 @@ def initialize_model_parallel_for_nemo( app_state.tensor_model_parallel_size = tensor_model_parallel_size app_state.pipeline_model_parallel_size = pipeline_model_parallel_size app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size - app_state.sequence_parallel_size = sequence_parallel_size + app_state.context_parallel_size = context_parallel_size app_state.use_fp8 = use_fp8 ( app_state.tensor_model_parallel_rank, @@ -99,7 +99,7 @@ def initialize_model_parallel_for_nemo( pipeline_model_parallel_size_=pipeline_model_parallel_size, virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank, - sequence_parallel_size_=sequence_parallel_size, + context_parallel_size_=context_parallel_size, ) # update apex.transformer globals @@ -177,7 +177,7 @@ def 
fake_initialize_model_parallel( pipeline_model_parallel_size_, pipeline_model_parallel_split_rank_=None, virtual_pipeline_model_parallel_size_=None, - sequence_parallel_size_=1, + context_parallel_size_=1, ): """ Fake initialize model data parallel groups so that we can instantiate model parallel models before DDP is initialized. @@ -188,7 +188,7 @@ def fake_initialize_model_parallel( Arguments: tensor_model_parallel_size: number of GPUs used to parallelize model tensor. pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline. - sequence_parallel_size: number of GPUs used to parallelize tokens of each input sequence. + context_parallel_size: number of GPUs used to parallelize tokens of each input. Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize @@ -211,12 +211,12 @@ def fake_initialize_model_parallel( tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size) pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size) model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size - sequence_parallel_size = min(sequence_parallel_size_, world_size) + context_parallel_size = min(context_parallel_size_, world_size) assert ( - world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * sequence_parallel_size) == 0 - ), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size} times sequence_parallel_size {sequence_parallel_size}' - data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size * sequence_parallel_size) + world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size) == 0 + ), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size} times context_parallel_size {context_parallel_size}' + data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size) num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size @@ -237,30 +237,30 @@ def fake_initialize_model_parallel( data_parallel_group = list(ranks) logging.info(f'Rank {rank} has data parallel group: {data_parallel_group}') - data_parallel_rank = data_parallel_group.index(rank) // sequence_parallel_size + data_parallel_rank = data_parallel_group.index(rank) // context_parallel_size logging.info(f'All data parallel group ranks: {all_data_parallel_group_ranks}') logging.info(f'Ranks {rank} has data parallel rank: {data_parallel_rank}') - # Build the sequence-parallel groups. - all_sequence_parallel_group_ranks = [] + # Build the context-parallel groups. 
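# Illustrative sketch (not from the patch): the rank ordering implied by the
# loops in this file, for assumed toy sizes. Tensor-parallel rank varies
# fastest, then context-parallel, then data-parallel, then pipeline stage,
# which is why data_parallel_rank is the group index divided by context_parallel_size.
def _decompose_rank(rank, tp=2, cp=2, dp=2, pp=1):
    tp_rank = rank % tp
    cp_rank = (rank // tp) % cp
    dp_rank = (rank // (tp * cp)) % dp
    pp_rank = rank // (tp * cp * dp)
    return tp_rank, cp_rank, dp_rank, pp_rank
# With world_size = tp * cp * dp * pp = 8, rank 6 -> (tp 0, cp 1, dp 1, pp 0):
# its context-parallel group is [4, 6] and its combined data/context group is
# [0, 2, 4, 6], where group index 3 // cp gives data-parallel rank 1.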
+ all_context_parallel_group_ranks = [] for i in range(pipeline_model_parallel_size): for j in range(data_parallel_size): - start_rank = i * num_pipeline_model_parallel_groups + j * tensor_model_parallel_size * sequence_parallel_size - end_rank = i * num_pipeline_model_parallel_groups + (j + 1) * tensor_model_parallel_size * sequence_parallel_size + start_rank = i * num_pipeline_model_parallel_groups + j * tensor_model_parallel_size * context_parallel_size + end_rank = i * num_pipeline_model_parallel_groups + (j + 1) * tensor_model_parallel_size * context_parallel_size for k in range(tensor_model_parallel_size): ranks = range(start_rank + k, end_rank, tensor_model_parallel_size) - all_sequence_parallel_group_ranks.append(list(ranks)) + all_context_parallel_group_ranks.append(list(ranks)) if rank in ranks: - sequence_parallel_group = list(ranks) - logging.info(f'Rank {rank} has sequence parallel group: {sequence_parallel_group}') + context_parallel_group = list(ranks) + logging.info(f'Rank {rank} has context parallel group: {context_parallel_group}') - sequence_parallel_rank = sequence_parallel_group.index(rank) - logging.info(f'All sequence parallel group ranks: {all_sequence_parallel_group_ranks}') - logging.info(f'Ranks {rank} has sequence parallel rank: {sequence_parallel_rank}') + context_parallel_rank = context_parallel_group.index(rank) + logging.info(f'All context parallel group ranks: {all_context_parallel_group_ranks}') + logging.info(f'Ranks {rank} has context parallel rank: {context_parallel_rank}') # Build the model-parallel groups. all_model_parallel_group_ranks = [] - for i in range(data_parallel_size * sequence_parallel_size): + for i in range(data_parallel_size * context_parallel_size): ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks] all_model_parallel_group_ranks.append(ranks) if rank in ranks: diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index c7bcf22b5b99..d2f52d47b7d0 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -169,7 +169,7 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None: virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank, use_fp8=app_state.use_fp8, - sequence_parallel_size=app_state.sequence_parallel_size, + context_parallel_size=app_state.context_parallel_size, ) # assert that fake tp and pp rank match after model parallel init diff --git a/nemo/utils/app_state.py b/nemo/utils/app_state.py index 3af7f9a92ac8..77a8d44916b5 100644 --- a/nemo/utils/app_state.py +++ b/nemo/utils/app_state.py @@ -55,7 +55,7 @@ def __init__(self): self._data_parallel_group = None self._megatron_checkpoint_version = None self._use_fp8 = False - self._sequence_parallel_size = None + self._context_parallel_size = None self._random_seed = None @@ -365,20 +365,20 @@ def use_fp8(self, use_fp8): self._use_fp8 = use_fp8 @property - def sequence_parallel_size(self): - """ Property returns the number of GPUs in each sequence parallel group. + def context_parallel_size(self): + """ Property returns the number of GPUs in each context parallel group. Returns: - Number of GPUs in each sequence parallel group. + Number of GPUs in each context parallel group. 
""" - return self._sequence_parallel_size + return self._context_parallel_size - @sequence_parallel_size.setter - def sequence_parallel_size(self, size): - """ Property sets the number of GPUs in each sequence parallel group. + @context_parallel_size.setter + def context_parallel_size(self, size): + """ Property sets the number of GPUs in each context parallel group. Args: - size (int): Number of GPUs in each sequence parallel group. + size (int): Number of GPUs in each context parallel group. """ - self._sequence_parallel_size = size + self._context_parallel_size = size @property def random_seed(self): From f7c9b5b91efcdfbe82b93cc9c20be99b797d2f25 Mon Sep 17 00:00:00 2001 From: xren Date: Sat, 24 Jun 2023 15:10:51 -0700 Subject: [PATCH 16/47] minor change Signed-off-by: xren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 5ea0df702122..3b42f3c725a9 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -750,7 +750,7 @@ def __next__(self): def get_batch_on_this_context_parallel_rank(self, batch): cp_size = self.cfg.get('context_parallel_size', 1) - loss_mask_sum = None if 'loss_mask' not in batch else batch['loss_mask'].sum() + if cp_size > 1: cp_rank = parallel_state.get_context_parallel_rank() for key, val in batch.items(): @@ -761,6 +761,8 @@ def get_batch_on_this_context_parallel_rank(self, batch): val = val.index_select(seq_dim, index) val = val.view(*val.shape[0:seq_dim], len(index)*val.shape[seq_dim+1], *val.shape[(seq_dim+2):]) batch[key] = val + + loss_mask_sum = None if 'loss_mask' not in batch else batch['loss_mask'].sum() batch['loss_mask_sum'] = loss_mask_sum return batch From 49b1052dc2425de41327770d23514d877052504a Mon Sep 17 00:00:00 2001 From: xren Date: Sat, 24 Jun 2023 15:24:39 -0700 Subject: [PATCH 17/47] fix loss_mask_sum calculation Signed-off-by: xren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 3b42f3c725a9..40b141b5d697 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -750,6 +750,7 @@ def __next__(self): def get_batch_on_this_context_parallel_rank(self, batch): cp_size = self.cfg.get('context_parallel_size', 1) + loss_mask_sum = None if 'loss_mask' not in batch else batch['loss_mask'].sum() if cp_size > 1: cp_rank = parallel_state.get_context_parallel_rank() @@ -762,7 +763,6 @@ def get_batch_on_this_context_parallel_rank(self, batch): val = val.view(*val.shape[0:seq_dim], len(index)*val.shape[seq_dim+1], *val.shape[(seq_dim+2):]) batch[key] = val - loss_mask_sum = None if 'loss_mask' not in batch else batch['loss_mask'].sum() batch['loss_mask_sum'] = loss_mask_sum return batch From 2c43687a19cec346b730f8dabe5f352965cc88f1 Mon Sep 17 00:00:00 2001 From: xren Date: Wed, 2 Aug 2023 17:10:26 -0700 Subject: [PATCH 18/47] make sure do not call megatron-core parallel_state while cp_size is 1 Signed-off-by: xren --- .../language_modeling/megatron_gpt_model.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 
deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 0817ef9e6b29..2c407dedac74 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -799,7 +799,7 @@ def get_batch_on_this_context_parallel_rank(self, batch): val = val.view(*val.shape[0:seq_dim], 2*cp_size, val.shape[seq_dim]//(2*cp_size), *val.shape[(seq_dim+1):]) index = torch.tensor([cp_rank, (2*cp_size-cp_rank-1)], device=val.device) val = val.index_select(seq_dim, index) - val = val.view(*val.shape[0:seq_dim], len(index)*val.shape[seq_dim+1], *val.shape[(seq_dim+2):]) + val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim+2):]) batch[key] = val batch['loss_mask_sum'] = loss_mask_sum @@ -955,7 +955,9 @@ def loss_func(self, loss_mask, loss_mask_sum, output_tensor): loss_mask = loss_mask.view(-1).float() # TODO: add nemo version here loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum # sequence level nll - torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group()) + cp_size = self.cfg.get('context_parallel_size', 1) + if cp_size > 1: + torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group()) return loss def build_train_valid_test_datasets(self): @@ -1256,12 +1258,11 @@ def setup_transformer_engine_tp_groups(self): else: self._set_tp_groups(self.model) - def _set_cp_running(self, module): + def _set_cp_running(self, module, cp_stream): """ Helper method to set cp running for transformer engine""" if self.cfg.get('transformer_engine', False): logging.info(f'Setting up transformer engine modules for context parallelism.') - cp_stream = torch.cuda.Stream() if self.cfg.get('megatron_amp_O2', 'False'): # when using O2 additional module key is added that casts the weights for layer in module.module.language_model.encoder.layers: @@ -1279,11 +1280,17 @@ def setup_transformer_engine_cp_running(self): """ This should be called after context parallel groups have been initialized and only needs to be called when using Transformer Engine. 
""" + cp_size = self.cfg.get('context_parallel_size', 1) + if cp_size == 1: + return + + cp_stream = torch.cuda.Stream() + if isinstance(self.model, list): for module in self.model: - self._set_cp_running(module) + self._set_cp_running(module, cp_stream) else: - self._set_cp_running(self.model) + self._set_cp_running(self.model, cp_stream) def on_save_checkpoint(self, checkpoint) -> None: """LightningModule hook: From 61af55183cc6c91d6de15ad02f0877683ee24775 Mon Sep 17 00:00:00 2001 From: xren Date: Thu, 3 Aug 2023 14:39:53 -0700 Subject: [PATCH 19/47] slice position embedding for different CP rank Signed-off-by: xren --- .../language_modeling/megatron/gpt_model.py | 2 ++ .../language_modeling/megatron_gpt_model.py | 1 + .../modules/common/megatron/language_model.py | 28 +++++++++++++++++++ 3 files changed, 31 insertions(+) diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py index d70c3e06bf01..f7e19f2493fa 100755 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py @@ -167,6 +167,7 @@ def __init__( ub_tp_comm_overlap=False, use_flash_attention=False, seq_len_interpolation_factor=None, + context_parallel=False, ): super(GPTModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights) @@ -251,6 +252,7 @@ def __init__( ub_tp_comm_overlap=ub_tp_comm_overlap, use_flash_attention=use_flash_attention, seq_len_interpolation_factor=seq_len_interpolation_factor, + context_parallel=context_parallel, ) if self.share_embeddings_and_output_weights: diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 2c407dedac74..866b228b18cf 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -372,6 +372,7 @@ def model_provider_func(self, pre_process, post_process): use_flash_attention=self.cfg.get('use_flash_attention', False), megatron_legacy=self.cfg.get('megatron_legacy', False), seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None), + context_parallel=(self.cfg.get('context_parallel_size', 1) > 1), ) return model diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py index 2aa2e8a3860e..3c7720d1e0b0 100755 --- a/nemo/collections/nlp/modules/common/megatron/language_model.py +++ b/nemo/collections/nlp/modules/common/megatron/language_model.py @@ -124,6 +124,7 @@ def get_language_model( ub_tp_comm_overlap=False, use_flash_attention=False, seq_len_interpolation_factor=None, + context_parallel=False, ): """Build language model and return along with the key to save.""" @@ -202,6 +203,7 @@ def get_language_model( ub_tp_comm_overlap=ub_tp_comm_overlap, use_flash_attention=use_flash_attention, seq_len_interpolation_factor=seq_len_interpolation_factor, + context_parallel=context_parallel, ) # key used for checkpoints. 
language_model_key = 'language_model' @@ -511,6 +513,7 @@ def __init__( ub_tp_comm_overlap=False, use_flash_attention=False, seq_len_interpolation_factor=None, + context_parallel=False, ): super(TransformerLanguageModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights) @@ -531,6 +534,7 @@ def __init__( self.position_embedding_type = position_embedding_type self.share_embeddings_and_output_weights = share_embeddings_and_output_weights self.sequence_parallel = sequence_parallel + self.context_parallel = context_parallel self.dtype = utils_funcs.dtype_from_precision(precision, megatron_amp_O2) if kv_channels is None: @@ -726,6 +730,19 @@ def set_input_tensor(self, input_tensor): self.encoder.set_input_tensor(input_tensor[0]) + def get_position_embedding_on_this_context_parallel_rank(self, position_embedding, seq_dim) + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = parallel_state.get_context_parallel_rank() + cp_idx = torch.tensor([cp_rank, (2*cp_size-cp_rank-1)], device=position_embedding.device) + position_embedding = position_embedding.view( + *position_embedding.shape[:seq_dim], 2*cp_size, -1, *position_embedding.shape[(seq_dim+1):] + ) + position_embedding = position_embedding.index_select(seq_dim, cp_idx) + position_embedding = position_embedding.view( + *position_embedding.shape[:seq_dim], -1, *position_embedding.shape[(seq_dim+2):] + ) + return position_embedding + def forward( self, enc_input_ids, @@ -779,10 +796,16 @@ def forward( else: enc_seq_length = encoder_input.size(0) + if self.context_parallel: + enc_seq_length = enc_seq_length * parallel_state.get_context_parallel_world_size() + rotary_pos_emb = None encoder_self_attention_relative_position_bias = None if self.position_embedding_type == 'rope': rotary_pos_emb = self.rotary_pos_emb(enc_seq_length) + + if self.context_parallel: + rotary_pos_emb = self.get_position_embedding_on_this_context_parallel_rank(rotary_pos_emb, 0) elif ( self.position_embedding_type == 'alibi' or self.position_embedding_type == 'sandwich' @@ -794,6 +817,11 @@ def forward( # causal attention bias: [1, head, 1, k] # non-causal attention bias: [1, head, q, k] + if self.context_parallel and encoder_self_attention_relative_position_bias.shape[-2] > 1: + encoder_self_attention_relative_position_bias = self.get_position_embedding_on_this_context_parallel_rank( + encoder_self_attention_relative_position_bias, -2 + ) + # encoder. if enc_hidden_states is None: encoder_output = self.encoder( From dc8a540671607d8e20c44427bb1b9751e8256ce7 Mon Sep 17 00:00:00 2001 From: xren Date: Thu, 3 Aug 2023 15:52:31 -0700 Subject: [PATCH 20/47] fix mising property decorator Signed-off-by: xren --- nemo/utils/app_state.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/utils/app_state.py b/nemo/utils/app_state.py index 430dd4a9b9f9..eb6b6d91ba5e 100644 --- a/nemo/utils/app_state.py +++ b/nemo/utils/app_state.py @@ -381,6 +381,7 @@ def context_parallel_size(self, size): """ self._context_parallel_size = size + @property def init_mpi_proc_group(self): """ Property sets the initialization of mpi process group. 
Returns: From 46479c6127a4075d462c940d700326f109a530ab Mon Sep 17 00:00:00 2001 From: xren Date: Thu, 3 Aug 2023 16:00:32 -0700 Subject: [PATCH 21/47] typo fix Signed-off-by: xren --- nemo/collections/nlp/modules/common/megatron/language_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py index 3c7720d1e0b0..9f12b7a20986 100755 --- a/nemo/collections/nlp/modules/common/megatron/language_model.py +++ b/nemo/collections/nlp/modules/common/megatron/language_model.py @@ -730,7 +730,7 @@ def set_input_tensor(self, input_tensor): self.encoder.set_input_tensor(input_tensor[0]) - def get_position_embedding_on_this_context_parallel_rank(self, position_embedding, seq_dim) + def get_position_embedding_on_this_context_parallel_rank(self, position_embedding, seq_dim): cp_size = parallel_state.get_context_parallel_world_size() cp_rank = parallel_state.get_context_parallel_rank() cp_idx = torch.tensor([cp_rank, (2*cp_size-cp_rank-1)], device=position_embedding.device) From b64b5635677d722c51ae0d48810ce3593dc5018b Mon Sep 17 00:00:00 2001 From: xren Date: Thu, 3 Aug 2023 18:42:09 -0700 Subject: [PATCH 22/47] fix rpe_bias CP slicing Signed-off-by: xren --- nemo/collections/nlp/modules/common/megatron/language_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py index 9f12b7a20986..d702a8a3e11f 100755 --- a/nemo/collections/nlp/modules/common/megatron/language_model.py +++ b/nemo/collections/nlp/modules/common/megatron/language_model.py @@ -819,7 +819,7 @@ def forward( if self.context_parallel and encoder_self_attention_relative_position_bias.shape[-2] > 1: encoder_self_attention_relative_position_bias = self.get_position_embedding_on_this_context_parallel_rank( - encoder_self_attention_relative_position_bias, -2 + encoder_self_attention_relative_position_bias, 2 ) # encoder. 
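
The context-parallel slicing introduced in PATCH 19 and corrected in PATCH 22 can be read as follows: the sequence dimension is split into 2*cp_size equal chunks, and each CP rank keeps the pair of chunks (cp_rank, 2*cp_size - cp_rank - 1), so with causal attention every rank owns one "cheap" early chunk and one "expensive" late chunk and the attention work stays roughly balanced. Below is a minimal standalone sketch of that indexing; the helper name slice_seq_for_cp_rank and the toy shapes are illustrative only and do not appear in the patches, which apply the same view/index_select/view pattern inside get_position_embedding_on_this_context_parallel_rank and get_batch_on_this_context_parallel_rank.

import torch

def slice_seq_for_cp_rank(x: torch.Tensor, seq_dim: int, cp_size: int, cp_rank: int) -> torch.Tensor:
    # Split the sequence dimension into 2*cp_size equal chunks and keep the
    # pair (cp_rank, 2*cp_size - cp_rank - 1) owned by this CP rank.
    cp_idx = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1], device=x.device)
    x = x.view(*x.shape[:seq_dim], 2 * cp_size, -1, *x.shape[seq_dim + 1:])
    x = x.index_select(seq_dim, cp_idx)
    # Merge the two selected chunks back into one contiguous local slice.
    return x.view(*x.shape[:seq_dim], -1, *x.shape[seq_dim + 2:])

# Toy check: a position embedding of sequence length 16 (shape [S, 1, 1, D])
# with cp_size = 2; rank 0 keeps tokens 0-3 and 12-15, rank 1 keeps 4-11.
pos_emb = torch.arange(16, dtype=torch.float32).view(16, 1, 1, 1)
for rank in range(2):
    local = slice_seq_for_cp_rank(pos_emb, seq_dim=0, cp_size=2, cp_rank=rank)
    print(rank, local.flatten().tolist())

The same rule explains the tensor_shape change in PATCH 23 (local sequence length is encoder_seq_length // context_parallel_size) and the seq_dim choice in get_batch_on_this_context_parallel_rank, where the attention mask is sliced along dimension 2 while tokens, labels, and loss_mask are sliced along dimension 1.
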
From e1654fb045a33a401a043a01cad9768ffedfc8f8 Mon Sep 17 00:00:00 2001 From: xren Date: Sat, 5 Aug 2023 18:05:12 -0700 Subject: [PATCH 23/47] code style fix Signed-off-by: xren --- .../language_modeling/megatron_gpt_model.py | 35 +++++++++++++------ .../modules/common/megatron/language_model.py | 6 ++-- .../modules/common/megatron/megatron_init.py | 12 +++++-- 3 files changed, 36 insertions(+), 17 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index bf48e81429fb..967b9a6859ef 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -460,7 +460,11 @@ def forward(self, tokens, text_position_ids, attention_mask, labels): return output_tensor def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): - tensor_shape = [self.cfg.encoder_seq_length // self.cfg.get('context_parallel_size', 1), self.cfg.micro_batch_size, self.cfg.hidden_size] + tensor_shape = [ + self.cfg.encoder_seq_length // self.cfg.get('context_parallel_size', 1), + self.cfg.micro_batch_size, + self.cfg.hidden_size, + ] # handle asynchronous grad reduction no_sync_func = None @@ -797,10 +801,15 @@ def get_batch_on_this_context_parallel_rank(self, batch): for key, val in batch.items(): if val is not None: seq_dim = 1 if key != 'attnetion_mask' else 2 - val = val.view(*val.shape[0:seq_dim], 2*cp_size, val.shape[seq_dim]//(2*cp_size), *val.shape[(seq_dim+1):]) - index = torch.tensor([cp_rank, (2*cp_size-cp_rank-1)], device=val.device) + val = val.view( + *val.shape[0:seq_dim], + 2 * cp_size, + val.shape[seq_dim] // (2 * cp_size), + *val.shape[(seq_dim + 1) :], + ) + index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=val.device) val = val.index_select(seq_dim, index) - val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim+2):]) + val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :]) batch[key] = val batch['loss_mask_sum'] = loss_mask_sum @@ -863,7 +872,7 @@ def loss_func(output_tensor): return loss_for_ub, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu} else: reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) - return loss_for_ub*cp_size, {'avg': reduced_loss} + return loss_for_ub * cp_size, {'avg': reduced_loss} return output_tensor, loss_func @@ -1274,15 +1283,19 @@ def _set_cp_running(self, module, cp_stream): if self.cfg.get('megatron_amp_O2', 'False'): # when using O2 additional module key is added that casts the weights for layer in module.module.language_model.encoder.layers: - layer.set_context_parallel_running(parallel_state.get_context_parallel_group(), - parallel_state.get_context_parallel_global_ranks(), - cp_stream) + layer.set_context_parallel_running( + parallel_state.get_context_parallel_group(), + parallel_state.get_context_parallel_global_ranks(), + cp_stream, + ) else: for layer in module.language_model.encoder.layers: - layer.set_context_parallel_running(parallel_state.get_context_parallel_group(), - parallel_state.get_context_parallel_global_ranks(), - cp_stream) + layer.set_context_parallel_running( + parallel_state.get_context_parallel_group(), + parallel_state.get_context_parallel_global_ranks(), + cp_stream, + ) def setup_transformer_engine_cp_running(self): """ This should be called after context parallel groups have been initialized diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py 
b/nemo/collections/nlp/modules/common/megatron/language_model.py index d702a8a3e11f..c978737ae00c 100755 --- a/nemo/collections/nlp/modules/common/megatron/language_model.py +++ b/nemo/collections/nlp/modules/common/megatron/language_model.py @@ -733,13 +733,13 @@ def set_input_tensor(self, input_tensor): def get_position_embedding_on_this_context_parallel_rank(self, position_embedding, seq_dim): cp_size = parallel_state.get_context_parallel_world_size() cp_rank = parallel_state.get_context_parallel_rank() - cp_idx = torch.tensor([cp_rank, (2*cp_size-cp_rank-1)], device=position_embedding.device) + cp_idx = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=position_embedding.device) position_embedding = position_embedding.view( - *position_embedding.shape[:seq_dim], 2*cp_size, -1, *position_embedding.shape[(seq_dim+1):] + *position_embedding.shape[:seq_dim], 2 * cp_size, -1, *position_embedding.shape[(seq_dim + 1) :] ) position_embedding = position_embedding.index_select(seq_dim, cp_idx) position_embedding = position_embedding.view( - *position_embedding.shape[:seq_dim], -1, *position_embedding.shape[(seq_dim+2):] + *position_embedding.shape[:seq_dim], -1, *position_embedding.shape[(seq_dim + 2) :] ) return position_embedding diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index 49654e35db62..0515d4aa2974 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -218,7 +218,9 @@ def fake_initialize_model_parallel( assert ( world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size) == 0 ), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size} times context_parallel_size {context_parallel_size}' - data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size) + data_parallel_size = world_size // ( + tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size + ) num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size @@ -247,8 +249,12 @@ def fake_initialize_model_parallel( all_context_parallel_group_ranks = [] for i in range(pipeline_model_parallel_size): for j in range(data_parallel_size): - start_rank = i * num_pipeline_model_parallel_groups + j * tensor_model_parallel_size * context_parallel_size - end_rank = i * num_pipeline_model_parallel_groups + (j + 1) * tensor_model_parallel_size * context_parallel_size + start_rank = ( + i * num_pipeline_model_parallel_groups + j * tensor_model_parallel_size * context_parallel_size + ) + end_rank = ( + i * num_pipeline_model_parallel_groups + (j + 1) * tensor_model_parallel_size * context_parallel_size + ) for k in range(tensor_model_parallel_size): ranks = range(start_rank + k, end_rank, tensor_model_parallel_size) all_context_parallel_group_ranks.append(list(ranks)) From 4f0a3bed20226fa7f160914b0fcaec63afc05cbb Mon Sep 17 00:00:00 2001 From: xren Date: Tue, 8 Aug 2023 16:16:51 -0700 Subject: [PATCH 24/47] fix loss_mask_sum calculation Signed-off-by: xren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git 
a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 967b9a6859ef..ea33cbbd2897 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -794,7 +794,9 @@ def __next__(self): def get_batch_on_this_context_parallel_rank(self, batch): cp_size = self.cfg.get('context_parallel_size', 1) - loss_mask_sum = None if 'loss_mask' not in batch else batch['loss_mask'].sum() + loss_mask_sum = None + if 'loss_mask' in batch and batch['loss_mask'] is not None: + loss_mask_sum = batch['loss_mask'].sum() if cp_size > 1: cp_rank = parallel_state.get_context_parallel_rank() From 4076d06a3e71e34f022b3c1d6e7a200ed239000d Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Mon, 21 Aug 2023 22:33:22 -0700 Subject: [PATCH 25/47] do not load attention mask if it's not needed Signed-off-by: Xiaowei Ren --- .../language_modeling/megatron/gpt_dataset.py | 46 +++++++++++++------ .../language_modeling/megatron_gpt_model.py | 2 +- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py index d7113e7cdde3..512f5b07f44d 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py @@ -313,6 +313,7 @@ def __init__( self.indexed_dataset = indexed_dataset self.drop_last = drop_last self.seq_length = seq_length + self.get_attention_mask_from_fusion = cfg.get('get_attention_mask_from_fusion', True) # Checks assert np.min(documents) >= 0 @@ -433,13 +434,21 @@ def __getitem__(self, idx): logging.debug('Got negative index. 
Masking loss from this sample') loss_mask = torch.zeros_like(loss_mask) - return { - 'tokens': tokens, - 'labels': labels, - 'attention_mask': attention_mask, - 'loss_mask': loss_mask, - 'position_ids': position_ids, - } + if self.get_attention_mask_from_fusion: + return { + 'tokens': tokens, + 'labels': labels, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + } + else: + return { + 'tokens': tokens, + 'labels': labels, + 'attention_mask': attention_mask, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + } class MockGPTDataset(Dataset): @@ -457,6 +466,7 @@ def __init__( self.vocab_size = tokenizer.vocab_size self.length = num_samples self.seed = seed + self.get_attention_mask_from_fusion = cfg.get('get_attention_mask_from_fusion', True) self.attention_mask = torch.tril(torch.ones((self.seq_length, self.seq_length))).unsqueeze(0) self.attention_mask = self.attention_mask < 0.5 @@ -476,13 +486,21 @@ def __getitem__(self, idx): tokens = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64)) labels = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64)) - return { - 'tokens': tokens, - 'labels': labels, - 'attention_mask': self.attention_mask, - 'loss_mask': self.loss_mask, - 'position_ids': self.position_ids, - } + if self.get_attention_mask_from_fusion: + return { + 'tokens': tokens, + 'labels': labels, + 'loss_mask': self.loss_mask, + 'position_ids': self.position_ids, + } + else: + return { + 'tokens': tokens, + 'labels': labels, + 'attention_mask': self.attention_mask, + 'loss_mask': self.loss_mask, + 'position_ids': self.position_ids, + } @torch.no_grad() diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index ea33cbbd2897..179cf276c053 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -844,7 +844,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ output_tensor = model( batch['tokens'], batch['position_ids'], - batch['attention_mask'], + None if self.get_attention_mask_from_fusion else batch['attention_mask'], batch['labels'], checkpoint_activations_all_layers=checkpoint_activations_all_layers, ) From 433f6a72f309d892942f64bedaac6fa3dccec138 Mon Sep 17 00:00:00 2001 From: xren Date: Tue, 22 Aug 2023 17:02:33 -0700 Subject: [PATCH 26/47] bug fix Signed-off-by: xren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 9b0111fe412b..eee91a714cf4 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -851,7 +851,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ required_keys.update(('tokens', 'position_ids')) if parallel_state.is_pipeline_last_stage(): required_keys.update(('labels', 'loss_mask')) - if self.get_attention_mask_from_fusion: + if self.get_attention_mask_from_fusion and 'attention_mask' in required_keys: required_keys.remove('attention_mask') batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()} From 5efaa763d72b973fad73676e1a119ec6fcd857f7 Mon 
Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Tue, 5 Sep 2023 16:58:23 -0700 Subject: [PATCH 27/47] fix ubuf size with CP > 1 Signed-off-by: Xiaowei Ren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index fe5fae302eeb..3acf95f6faf5 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -552,7 +552,9 @@ def initialize_ub_func(self): ) input_shape = [ - self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'), + self.cfg.get('encoder_seq_length') + * self.cfg.get('micro_batch_size') + // self.cfg.get('context_parallel_size', 1), self.cfg.get('hidden_size'), ] From 006677d410076e2c92e0de0a697df01703b30bf8 Mon Sep 17 00:00:00 2001 From: xren Date: Wed, 13 Sep 2023 18:20:25 -0700 Subject: [PATCH 28/47] address naming confusion of mixed dp and cp Signed-off-by: xren --- .../modules/common/megatron/megatron_init.py | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index 0515d4aa2974..013838e7688e 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -230,19 +230,28 @@ def fake_initialize_model_parallel( virtual_pipeline_model_parallel_rank = 0 # Build the data-parallel groups. - all_data_parallel_group_ranks = [] + all_data_parallel_group_ranks_with_cp = [] for i in range(pipeline_model_parallel_size): start_rank = i * num_pipeline_model_parallel_groups end_rank = (i + 1) * num_pipeline_model_parallel_groups - for j in range(tensor_model_parallel_size): - ranks = range(start_rank + j, end_rank, tensor_model_parallel_size) - all_data_parallel_group_ranks.append(list(ranks)) + for j in range(context_parallel_size * tensor_model_parallel_size): + ranks = range(start_rank + j, end_rank, context_parallel_size * tensor_model_parallel_size) if rank in ranks: data_parallel_group = list(ranks) - logging.info(f'Rank {rank} has data parallel group: {data_parallel_group}') + logging.info(f'Rank {rank} has data parallel group : {data_parallel_group}') + for j in range(tensor_model_parallel_size): + ranks_with_cp = range(start_rank + j, end_rank, tensor_model_parallel_size) + all_data_parallel_group_ranks_with_cp.append(list(ranks_with_cp)) + if rank in ranks_with_cp: + data_parallel_group_with_cp = list(ranks_with_cp) + logging.info( + f'Rank {rank} has combined group of data parallel and context parallel : {data_parallel_group_with_cp}' + ) - data_parallel_rank = data_parallel_group.index(rank) // context_parallel_size - logging.info(f'All data parallel group ranks: {all_data_parallel_group_ranks}') + data_parallel_rank = data_parallel_group.index(rank) + logging.info( + f'All data parallel group ranks with context parallel combined: {all_data_parallel_group_ranks_with_cp}' + ) logging.info(f'Ranks {rank} has data parallel rank: {data_parallel_rank}') # Build the context-parallel groups. @@ -269,7 +278,10 @@ def fake_initialize_model_parallel( # Build the model-parallel groups. 
all_model_parallel_group_ranks = [] for i in range(data_parallel_size * context_parallel_size): - ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks] + ranks = [ + data_parallel_group_ranks_with_cp[i] + for data_parallel_group_ranks_with_cp in all_data_parallel_group_ranks_with_cp + ] all_model_parallel_group_ranks.append(ranks) if rank in ranks: logging.info(f'Rank {rank} has model parallel group: {list(ranks)}') From 0f7d0797e5548c87c45d49cdb55b674215d320f3 Mon Sep 17 00:00:00 2001 From: xren Date: Tue, 3 Oct 2023 01:24:47 -0700 Subject: [PATCH 29/47] rewrite cp code by assuming with_context_parallel=False Signed-off-by: xren --- .../megatron/dataset_utils.py | 1 + .../language_modeling/megatron/gpt_dataset.py | 1 + .../language_modeling/megatron_base_model.py | 10 ++++-- .../language_modeling/megatron_gpt_model.py | 4 ++- .../modules/common/megatron/build_model.py | 6 +++- nemo/collections/nlp/parts/nlp_overrides.py | 3 +- nemo/core/optim/distributed_adam.py | 3 +- nemo/core/optim/optimizer_with_main_params.py | 36 ++++++++++++------- 8 files changed, 46 insertions(+), 18 deletions(-) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py index f286bb9a8adf..2d5722dc8c62 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py @@ -1330,6 +1330,7 @@ def get_samples_mapping( ) torch.distributed.barrier() counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=parallel_state.get_context_parallel_group()) torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group()) torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group()) assert counts[0].item() == ( diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py index 512f5b07f44d..4941e2d6b23e 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py @@ -692,6 +692,7 @@ def _build_index_mappings( torch.distributed.barrier() counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=parallel_state.get_context_parallel_group()) torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group()) torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group()) assert counts[0].item() == ( diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 8c8420ddc15e..352f3f1bbda0 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -371,12 +371,15 @@ def allreduce_gradients(self): # param.main_grad = param.grad # For each bucket, all-reduce and copy all-reduced grads. 
+ context_parallel = self.cfg.get('context_parallel_size', 1) > 1 for tp in buckets: bucket = buckets[tp] grads = [param.grad.data for param in bucket] coalesced = torch._utils._flatten_dense_tensors(grads) - coalesced /= parallel_state.get_data_parallel_world_size() - torch.distributed.all_reduce(coalesced, group=parallel_state.get_data_parallel_group()) + coalesced /= parallel_state.get_data_parallel_world_size(with_context_parallel=context_parallel) + torch.distributed.all_reduce( + coalesced, group=parallel_state.get_data_parallel_group(with_context_parallel=context_parallel) + ) for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)): buf.copy_(synced) @@ -451,6 +454,7 @@ def setup_optimization( self, optim_config: Optional[Union[DictConfig, Dict]] = None, optim_kwargs: Optional[Dict[str, Any]] = None, ): optim_kwargs = {} if optim_kwargs is None else optim_kwargs.copy() + optim_kwargs['context_parallel'] = self.cfg.get('context_parallel_size', 1) > 1 if self.with_distributed_adam: # Allocate contiguous buffer to avoid extra copies @@ -514,6 +518,7 @@ def configure_optimizers(self): else: grad_div_ar_fusion = False + context_parallel = self.cfg.get('context_parallel_size', 1) > 1 self._optimizer = MainParamsOptimizerWrapper( self._optimizer, fp32_grad_accum=fp32_grad_accum, @@ -521,6 +526,7 @@ def configure_optimizers(self): async_grad_allreduce=async_grad_allreduce, grad_div_ar_fusion=grad_div_ar_fusion, grad_allreduce_chunk_size_mb=self.cfg.get('grad_allreduce_chunk_size_mb', 125), + context_parallel=context_parallel, ) assert self._trainer.max_steps is not None, "'max_steps' is missing in trainer config." diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index aca851033146..22825e12d13e 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -230,12 +230,14 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): wrap_with_ddp=False, on_cpu=True, virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), + context_parallel=(self.cfg.get('context_parallel_size', 1) > 1), ) else: self.model = build_model( model_provider_func=self.model_provider_func, wrap_with_ddp=False, virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), + context_parallel=(self.cfg.get('context_parallel_size', 1) > 1), ) # if we're not using interleaved, then self.model is a module. 
@@ -880,7 +882,7 @@ def loss_func(output_tensor): loss_for_ub = self.loss_func(batch['loss_mask'], batch['loss_mask_sum'], output_tensor) cp_size = self.cfg.get('context_parallel_size', 1) if validation_step and not self.cfg.data.get('validation_drop_last', True): - num_valid_tokens_in_ub = batch['loss_mask'].sum() + num_valid_tokens_in_ub = batch['loss_mask_sum'] if loss_for_ub.isnan(): assert batch['loss_mask'].count_nonzero() == 0, 'Got NaN loss with non-empty input' loss_sum_for_ub = torch.zeros_like(num_valid_tokens_in_ub) diff --git a/nemo/collections/nlp/modules/common/megatron/build_model.py b/nemo/collections/nlp/modules/common/megatron/build_model.py index 929093405fce..bbcdeed9c089 100644 --- a/nemo/collections/nlp/modules/common/megatron/build_model.py +++ b/nemo/collections/nlp/modules/common/megatron/build_model.py @@ -48,6 +48,7 @@ def build_model( virtual_pipeline_model_parallel_size: Optional[int] = None, model_type: ModelType = ModelType.encoder_or_decoder, on_cpu: bool = False, + context_parallel: bool = False, *args: Any, **kwargs: Any, ) -> List[torch.nn.Module]: @@ -151,7 +152,10 @@ def build_model( i = torch.cuda.current_device() model = [ torch.nn.parallel.distributed.DistributedDataParallel( - model_module, device_ids=[i], output_device=i, process_group=parallel_state.get_data_parallel_group(), + model_module, + device_ids=[i], + output_device=i, + process_group=parallel_state.get_data_parallel_group(with_context_parallel=context_parallel), ) for model_module in model ] diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 5d5ba812575e..3390bb7241a1 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -156,9 +156,10 @@ def configure_ddp(self): # to False in PTL 2.0 and hence pre_configure_ddp() is removed in ddp.py # self.pre_configure_ddp() # device_ids = self.determine_ddp_device_ids() + context_parallel = app_state.context_parallel_size is not None and app_state.context_parallel_size > 1 self._model = DistributedDataParallel( _LightningModuleWrapperBase(self.model), - process_group=parallel_state.get_data_parallel_group(), + process_group=parallel_state.get_data_parallel_group(with_context_parallel=context_parallel), **self._ddp_kwargs, ) diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index d7bc049c1808..563374aee941 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -73,7 +73,8 @@ def __init__( # Initialize process groups if 'process_group' not in kwargs and not parallel_state.is_unitialized(): - kwargs['process_group'] = parallel_state.get_data_parallel_group() + context_parallel = kwargs['context_parallel'] if 'context_parallel' in kwargs else False + kwargs['process_group'] = parallel_state.get_data_parallel_group(with_context_parallel=context_parallel) if disable_distributed_parameters: world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() diff --git a/nemo/core/optim/optimizer_with_main_params.py b/nemo/core/optim/optimizer_with_main_params.py index 922412f1e8a6..682dd2877142 100644 --- a/nemo/core/optim/optimizer_with_main_params.py +++ b/nemo/core/optim/optimizer_with_main_params.py @@ -105,10 +105,10 @@ def zero(self): """Reset the buffer to zero.""" self.data.zero_() - def allreduce_buffer(self): + def allreduce_buffer(self, context_parallel=False): """Synchronous buffer data allreduce """ - 
self.data.div_(get_data_parallel_world_size()) - torch.distributed.all_reduce(self.data, group=get_data_parallel_group()) + self.data.div_(get_data_parallel_world_size(with_context_parallel=context_parallel)) + torch.distributed.all_reduce(self.data, group=get_data_parallel_group(with_context_parallel=context_parallel)) def get(self, shape, start_index): """Return a tensor with the input `shape` as a view into the @@ -176,6 +176,7 @@ def __init__( async_grad_allreduce=False, grad_div_ar_fusion=True, grad_allreduce_chunk_size_mb=0, + context_parallel=False, ): if not HAVE_APEX: raise ImportError( @@ -202,11 +203,14 @@ def __init__( self._fp32_grad_accum = fp32_grad_accum self._contiguous_grad_bucket = contiguous_grad_bucket + self._context_parallel = context_parallel # used with tensor parallel only (no pipeline parallelism) # be careful, weight update cannot start until all async grad AR works are done - self._async_grad_allreduce = async_grad_allreduce and get_data_parallel_world_size() > 1 - self._grad_divisor = 1 / get_data_parallel_world_size() + self._async_grad_allreduce = ( + async_grad_allreduce and get_data_parallel_world_size(with_context_parallel=self._context_parallel) > 1 + ) + self._grad_divisor = 1 / get_data_parallel_world_size(with_context_parallel=self._context_parallel) if self._async_grad_allreduce: # use @no_sync to disable backward grad sync during gradient accumulation @@ -340,27 +344,35 @@ def param_hook(*unused): if self._grad_div_ar_fusion: torch.distributed.all_reduce( allreduce_tensor, - group=get_data_parallel_group(), + group=get_data_parallel_group(with_context_parallel=self._context_parallel), async_op=True, op=torch.distributed._make_nccl_premul_sum(self._grad_divisor), ) else: - allreduce_tensor.div_(get_data_parallel_world_size()) + allreduce_tensor.div_( + get_data_parallel_world_size(with_context_parallel=self._context_parallel) + ) torch.distributed.all_reduce( - allreduce_tensor, group=get_data_parallel_group(), async_op=True, + allreduce_tensor, + group=get_data_parallel_group(with_context_parallel=self._context_parallel), + async_op=True, ) else: if self._grad_div_ar_fusion: torch.distributed.all_reduce( main_param.grad, - group=get_data_parallel_group(), + group=get_data_parallel_group(with_context_parallel=self._context_parallel), async_op=True, op=torch.distributed._make_nccl_premul_sum(self._grad_divisor), ) else: - main_param.grad.div_(get_data_parallel_world_size()) + main_param.grad.div_( + get_data_parallel_world_size(with_context_parallel=self._context_parallel) + ) torch.distributed.all_reduce( - main_param.grad, group=get_data_parallel_group(), async_op=True, + main_param.grad, + group=get_data_parallel_group(with_context_parallel=self._context_parallel), + async_op=True, ) return param_hook @@ -470,7 +482,7 @@ def load_state_dict(self, state_dict): def allreduce_main_grads(self): for i in self._main_grad_buffers: - self._main_grad_buffers[i].allreduce_buffer() + self._main_grad_buffers[i].allreduce_buffer(context_parallel=self._context_parallel) @contextmanager def no_sync(self): From 335195314ac4251243cf03679af24d4d5f75778b Mon Sep 17 00:00:00 2001 From: xren Date: Tue, 3 Oct 2023 15:05:56 -0700 Subject: [PATCH 30/47] pop context_parallel from dist opt kwargs Signed-off-by: xren --- nemo/core/optim/distributed_adam.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index 563374aee941..7a319d21a0e7 100644 --- a/nemo/core/optim/distributed_adam.py +++ 
b/nemo/core/optim/distributed_adam.py @@ -75,6 +75,7 @@ def __init__( if 'process_group' not in kwargs and not parallel_state.is_unitialized(): context_parallel = kwargs['context_parallel'] if 'context_parallel' in kwargs else False kwargs['process_group'] = parallel_state.get_data_parallel_group(with_context_parallel=context_parallel) + kwargs.pop('context_parallel', False) if disable_distributed_parameters: world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() From 08f785b064548c3c97e4168f9a59d3f645063ebc Mon Sep 17 00:00:00 2001 From: xren Date: Wed, 4 Oct 2023 23:47:49 -0700 Subject: [PATCH 31/47] make sure amax reduction group is aware of context parallelism Signed-off-by: xren --- nemo/collections/nlp/modules/common/megatron/transformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index 06a2e306482e..ac932f713dc7 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -1462,7 +1462,8 @@ def forward( # fp8_autocast will not do anything if TE or FP8 isn't used fp8_group = None if self.fp8 and parallel_state.model_parallel_is_initialized(): - fp8_group = parallel_state.get_amax_reduction_group() + cp_size = parallel_state.get_context_parallel_world_size() + fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=(cp_size > 1)) if HAVE_TE: # if TE is installed but fp8 is not available then this will do nothing From e277b3d0f300a96bd901d9479458e5704397bfa5 Mon Sep 17 00:00:00 2001 From: xren Date: Thu, 5 Oct 2023 00:47:12 -0700 Subject: [PATCH 32/47] remove use_fp8 from initialize_model_parallel Signed-off-by: xren --- nemo/collections/nlp/parts/nlp_overrides.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 3390bb7241a1..0dbfafd2a3da 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -196,7 +196,6 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None: pipeline_model_parallel_size=app_state.pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank, - use_fp8=app_state.use_fp8, context_parallel_size=app_state.context_parallel_size, ) From dc65d346bd3f376a078ee768f3cbe45560805ed8 Mon Sep 17 00:00:00 2001 From: xren Date: Thu, 5 Oct 2023 19:48:16 -0700 Subject: [PATCH 33/47] make implementaitons of setup_transformer_engine_tp_groups and setup_transformer_engine_cp_running consistent Signed-off-by: xren --- .../language_modeling/megatron_gpt_model.py | 45 ++++++------------- 1 file changed, 14 insertions(+), 31 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 7c8de9e68a14..bd5542fc1b19 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -1349,43 +1349,26 @@ def setup_transformer_engine_tp_groups(self): tp_group = parallel_state.get_tensor_model_parallel_group() child.set_tensor_parallel_group(tp_group) - def _set_cp_running(self, module, cp_stream): - """ Helper method to set cp 
running for transformer engine""" - - if self.cfg.get('transformer_engine', False): - logging.info(f'Setting up transformer engine modules for context parallelism.') - if self.cfg.get('megatron_amp_O2', 'False'): - # when using O2 additional module key is added that casts the weights - for layer in module.module.language_model.encoder.layers: - layer.set_context_parallel_running( - parallel_state.get_context_parallel_group(), - parallel_state.get_context_parallel_global_ranks(), - cp_stream, - ) - - else: - for layer in module.language_model.encoder.layers: - layer.set_context_parallel_running( - parallel_state.get_context_parallel_group(), - parallel_state.get_context_parallel_global_ranks(), - cp_stream, - ) - def setup_transformer_engine_cp_running(self): """ This should be called after context parallel groups have been initialized and only needs to be called when using Transformer Engine. """ - cp_size = self.cfg.get('context_parallel_size', 1) - if cp_size == 1: - return - cp_stream = torch.cuda.Stream() - if isinstance(self.model, list): - for module in self.model: - self._set_cp_running(module, cp_stream) - else: - self._set_cp_running(self.model, cp_stream) + for module in self.get_gpt_module_list(): + """Set context parallel running + Copied from: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py + """ + # Deep iterate but skip self to avoid infinite recursion. + for index, child in enumerate(module.modules()): + if index == 0: + continue + if hasattr(child, "set_context_parallel_running"): + child.set_context_parallel_running( + parallel_state.get_context_parallel_group(), + parallel_state.get_context_parallel_global_ranks(), + cp_stream, + ) def on_save_checkpoint(self, checkpoint) -> None: """LightningModule hook: From 50131894e6199358c6980bbe60d8037bd9e42ca3 Mon Sep 17 00:00:00 2001 From: xren Date: Tue, 10 Oct 2023 18:01:25 -0700 Subject: [PATCH 34/47] cp function renaming Signed-off-by: xren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 5fc20fd8353a..ed46ba77f122 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -1198,7 +1198,7 @@ def setup(self, stage=None): if self.cfg.get('transformer_engine', False) or self.cfg.get('mcore_gpt', False): self.setup_transformer_engine_tp_groups() - self.setup_transformer_engine_cp_running() + self.setup_transformer_engine_cp_groups() def setup_training_data(self, cfg): if hasattr(self, '_train_ds'): @@ -1257,7 +1257,7 @@ def dummy(): if self.cfg.get('transformer_engine', False): self.setup_transformer_engine_tp_groups() - self.setup_transformer_engine_cp_running() + self.setup_transformer_engine_cp_groups() # set the default sampling params if it is None. # default do greedy sampling @@ -1349,7 +1349,7 @@ def setup_transformer_engine_tp_groups(self): tp_group = parallel_state.get_tensor_model_parallel_group() child.set_tensor_parallel_group(tp_group) - def setup_transformer_engine_cp_running(self): + def setup_transformer_engine_cp_groups(self): """ This should be called after context parallel groups have been initialized and only needs to be called when using Transformer Engine. 
""" @@ -1363,8 +1363,8 @@ def setup_transformer_engine_cp_running(self): for index, child in enumerate(module.modules()): if index == 0: continue - if hasattr(child, "set_context_parallel_running"): - child.set_context_parallel_running( + if hasattr(child, "set_context_parallel_group"): + child.set_context_parallel_group( parallel_state.get_context_parallel_group(), parallel_state.get_context_parallel_global_ranks(), cp_stream, From 52dd50bee1668c80c76c4cdaff04e33b86966bdd Mon Sep 17 00:00:00 2001 From: xren Date: Thu, 12 Oct 2023 19:57:43 -0700 Subject: [PATCH 35/47] make loss logging broadcast aware of cp Signed-off-by: xren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index ed46ba77f122..6b5e7fdd9461 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -643,10 +643,11 @@ def training_step(self, dataloader_iter, batch_idx): if parallel_state.get_pipeline_model_parallel_world_size() > 1: if self.loss_broadcast_src_rank is None: dp_size = parallel_state.get_data_parallel_world_size() + cp_size = parallel_state.get_context_parallel_world_size() tp_size = parallel_state.get_tensor_model_parallel_world_size() pp_size = parallel_state.get_pipeline_model_parallel_world_size() - rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * tp_size) - last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1) + rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * cp_size * tp_size) + last_pipeline_stage_offset = (tp_size * cp_size * dp_size) * (pp_size - 1) self.loss_broadcast_src_rank = last_pipeline_stage_offset + rank_in_dp_tp_group torch.distributed.broadcast( loss_mean, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(), @@ -1017,10 +1018,11 @@ def on_validation_epoch_end(self): if parallel_state.get_pipeline_model_parallel_world_size() > 1: if self.loss_broadcast_src_rank is None: dp_size = parallel_state.get_data_parallel_world_size() + cp_size = parallel_state.get_context_parallel_world_size() tp_size = parallel_state.get_tensor_model_parallel_world_size() pp_size = parallel_state.get_pipeline_model_parallel_world_size() - rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * tp_size) - last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1) + rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * cp_size * tp_size) + last_pipeline_stage_offset = (tp_size * cp_size * dp_size) * (pp_size - 1) self.loss_broadcast_src_rank = last_pipeline_stage_offset + rank_in_dp_tp_group torch.distributed.broadcast( averaged_loss, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(), From 52381e8209d863e2ea74e18df15f9fb362daa2e2 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Fri, 13 Oct 2023 15:15:24 -0700 Subject: [PATCH 36/47] fix a typo Signed-off-by: Xiaowei Ren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index ed46ba77f122..ee681daf1cf4 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ 
b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -836,7 +836,7 @@ def get_batch_on_this_context_parallel_rank(self, batch): cp_rank = parallel_state.get_context_parallel_rank() for key, val in batch.items(): if val is not None: - seq_dim = 1 if key != 'attnetion_mask' else 2 + seq_dim = 1 if key != 'attention_mask' else 2 val = val.view( *val.shape[0:seq_dim], 2 * cp_size, From e39439204e7f89642051d8eb8312a0b1ed0e5646 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Fri, 13 Oct 2023 18:04:50 -0700 Subject: [PATCH 37/47] var name fix Signed-off-by: Xiaowei Ren --- .../language_modeling/megatron_gpt_model.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 455b4040f7d9..999c1fef9d85 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -829,9 +829,9 @@ def __next__(self): def get_batch_on_this_context_parallel_rank(self, batch): cp_size = self.cfg.get('context_parallel_size', 1) - loss_mask_sum = None + num_valid_tokens_in_ub = None if 'loss_mask' in batch and batch['loss_mask'] is not None: - loss_mask_sum = batch['loss_mask'].sum() + num_valid_tokens_in_ub = batch['loss_mask'].sum() if cp_size > 1: cp_rank = parallel_state.get_context_parallel_rank() @@ -849,7 +849,7 @@ def get_batch_on_this_context_parallel_rank(self, batch): val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :]) batch[key] = val - batch['loss_mask_sum'] = loss_mask_sum + batch['num_valid_tokens_in_ub'] = num_valid_tokens_in_ub return batch @@ -895,10 +895,10 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ def loss_func(output_tensor): # Loss for a micro-batch (ub) - loss_for_ub = self.loss_func(batch['loss_mask'], batch['loss_mask_sum'], output_tensor) + loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor) cp_size = self.cfg.get('context_parallel_size', 1) if validation_step and not self.cfg.data.get('validation_drop_last', True): - num_valid_tokens_in_ub = batch['loss_mask_sum'] + num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub'] if loss_for_ub.isnan(): assert batch['loss_mask'].count_nonzero() == 0, 'Got NaN loss with non-empty input' loss_sum_for_ub = torch.zeros_like(num_valid_tokens_in_ub) @@ -915,7 +915,7 @@ def loss_func(output_tensor): torch.distributed.all_reduce( loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group() ) - return loss_for_ub, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu} + return loss_for_ub * cp_size, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu} else: reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) return loss_for_ub * cp_size, {'avg': reduced_loss} @@ -1041,11 +1041,11 @@ def on_test_epoch_end(self): logging.info(f'test_loss: {averaged_loss[0]}') self.test_step_outputs.clear() # free memory - def loss_func(self, loss_mask, loss_mask_sum, output_tensor): + def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor): losses = output_tensor.float() loss_mask = loss_mask.view(-1).float() # TODO: add nemo version here - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum # sequence level nll + loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub # sequence level nll cp_size = 
self.cfg.get('context_parallel_size', 1) if cp_size > 1: torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group()) From f9bf0d8e383247c6b05c9ba1268ecf414f0715d4 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Mon, 16 Oct 2023 16:54:53 -0700 Subject: [PATCH 38/47] import transformer layer specs from MCore Signed-off-by: Xiaowei Ren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 999c1fef9d85..74581e626a2d 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -76,6 +76,7 @@ try: from megatron.core import InferenceParams, parallel_state from megatron.core.models.gpt import GPTModel as MCoreGPTModel + from megatron.core.models.gpt.gpt_layer_specs import gpt_layer_with_transformer_engine_spec from megatron.core.pipeline_parallel.schedules import get_forward_backward_func from megatron.core.transformer.module import Float16Module as MCoreFloat16Module from megatron.core.transformer.transformer_config import TransformerConfig @@ -310,6 +311,7 @@ def model_provider_func(self, pre_process, post_process): if self.mcore_gpt: model = MCoreGPTModel( config=self.transformer_config, + transformer_layer_spec=gpt_layer_with_transformer_engine_spec, vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size), max_sequence_length=self.cfg.get('encoder_seq_length', 512), pre_process=pre_process, From 1f8815f8c11468757293ac0489054062dfa07a01 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Mon, 16 Oct 2023 17:05:42 -0700 Subject: [PATCH 39/47] upgrade MCore version Signed-off-by: Xiaowei Ren --- Dockerfile | 2 +- Jenkinsfile | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 06f96a091a22..96e64219df81 100644 --- a/Dockerfile +++ b/Dockerfile @@ -46,7 +46,7 @@ WORKDIR /workspace/ # install megatron core, this can be removed once 0.3 pip package is released RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \ cd Megatron-LM && \ - git checkout ab0336a5c8eab77aa74ae604ba1e73decbf6d560 && \ + git checkout 954a65b04c01a4986adbad2a7cc9e9a2d094dd77 && \ pip install -e . WORKDIR /tmp/ diff --git a/Jenkinsfile b/Jenkinsfile index 3d262931915b..c4c9c52d8a95 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -59,11 +59,11 @@ pipeline { stage('Megatron Core installation') { steps { - // pinned MCore https://github.com/NVIDIA/Megatron-LM/commit/ab0336a5c8eab77aa74ae604ba1e73decbf6d560 + // pinned MCore https://github.com/NVIDIA/Megatron-LM/commit/954a65b04c01a4986adbad2a7cc9e9a2d094dd77 // ToT for 23.08 branch sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \ cd Megatron-LM && \ - git checkout ab0336a5c8eab77aa74ae604ba1e73decbf6d560 && \ + git checkout 954a65b04c01a4986adbad2a7cc9e9a2d094dd77 && \ pip install -e .' 
} } From d15ae1704c6f89a8c51b94a56c5aa678d09a3610 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Tue, 17 Oct 2023 14:38:28 -0700 Subject: [PATCH 40/47] add add context_parallel into the kwargs of dist opt Signed-off-by: Xiaowei Ren --- .../nlp/models/language_modeling/megatron_base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 5be28b83b4b2..17130c1931a7 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -456,8 +456,8 @@ def setup_optimization( self, optim_config: Optional[Union[DictConfig, Dict]] = None, optim_kwargs: Optional[Dict[str, Any]] = None, ): optim_kwargs = {} if optim_kwargs is None else optim_kwargs.copy() - optim_kwargs['context_parallel'] = self.cfg.get('context_parallel_size', 1) > 1 if self.with_distributed_adam: + optim_kwargs['context_parallel'] = self.cfg.get('context_parallel_size', 1) > 1 # Allocate contiguous buffer to avoid extra copies optim_kwargs['contiguous_grad_buffer'] = True From 55b7e13db77f6f543a501018b6c24671c71055a2 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Tue, 24 Oct 2023 20:16:08 -0700 Subject: [PATCH 41/47] remove redundant cp check Signed-off-by: Xiaowei Ren --- .../megatron/dataset_utils.py | 3 +-- .../language_modeling/megatron/gpt_dataset.py | 3 +-- .../language_modeling/megatron/gpt_model.py | 2 -- .../language_modeling/megatron_base_model.py | 9 ++----- .../language_modeling/megatron_gpt_model.py | 3 --- .../modules/common/megatron/build_model.py | 3 +-- .../modules/common/megatron/language_model.py | 5 +--- .../modules/common/megatron/transformer.py | 3 +-- nemo/collections/nlp/parts/nlp_overrides.py | 3 +-- nemo/core/optim/distributed_adam.py | 4 +-- nemo/core/optim/optimizer_with_main_params.py | 26 +++++++++---------- 11 files changed, 21 insertions(+), 43 deletions(-) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py index 2d5722dc8c62..17ffc01fb7f4 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py @@ -1330,8 +1330,7 @@ def get_samples_mapping( ) torch.distributed.barrier() counts = torch.cuda.LongTensor([1]) - torch.distributed.all_reduce(counts, group=parallel_state.get_context_parallel_group()) - torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group(with_context_parallel=True)) torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group()) assert counts[0].item() == ( torch.distributed.get_world_size() diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py index 1d5842a50cf1..b7fec4f38e1e 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py @@ -692,8 +692,7 @@ def _build_index_mappings( torch.distributed.barrier() counts = torch.cuda.LongTensor([1]) - torch.distributed.all_reduce(counts, group=parallel_state.get_context_parallel_group()) - torch.distributed.all_reduce(counts, 
group=parallel_state.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group(with_context_parallel=True)) torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group()) assert counts[0].item() == ( torch.distributed.get_world_size() diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py index bf7195d83f28..9b3bae2177b0 100755 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py @@ -168,7 +168,6 @@ def __init__( use_flash_attention=False, seq_len_interpolation_factor=None, rotary_base=10000, - context_parallel=False, ): super(GPTModel, self).__init__(config=config, share_token_embeddings=share_embeddings_and_output_weights) @@ -251,7 +250,6 @@ def __init__( use_flash_attention=use_flash_attention, seq_len_interpolation_factor=seq_len_interpolation_factor, rotary_base=rotary_base, - context_parallel=context_parallel, ) if self.share_embeddings_and_output_weights: diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 17130c1931a7..12b97378f24a 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -373,14 +373,13 @@ def allreduce_gradients(self): # param.main_grad = param.grad # For each bucket, all-reduce and copy all-reduced grads. - context_parallel = self.cfg.get('context_parallel_size', 1) > 1 for tp in buckets: bucket = buckets[tp] grads = [param.grad.data for param in bucket] coalesced = torch._utils._flatten_dense_tensors(grads) - coalesced /= parallel_state.get_data_parallel_world_size(with_context_parallel=context_parallel) + coalesced /= parallel_state.get_data_parallel_world_size(with_context_parallel=True) torch.distributed.all_reduce( - coalesced, group=parallel_state.get_data_parallel_group(with_context_parallel=context_parallel) + coalesced, group=parallel_state.get_data_parallel_group(with_context_parallel=True) ) for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)): buf.copy_(synced) @@ -457,8 +456,6 @@ def setup_optimization( ): optim_kwargs = {} if optim_kwargs is None else optim_kwargs.copy() if self.with_distributed_adam: - optim_kwargs['context_parallel'] = self.cfg.get('context_parallel_size', 1) > 1 - # Allocate contiguous buffer to avoid extra copies optim_kwargs['contiguous_grad_buffer'] = True @@ -520,7 +517,6 @@ def configure_optimizers(self): else: grad_div_ar_fusion = False - context_parallel = self.cfg.get('context_parallel_size', 1) > 1 self._optimizer = MainParamsOptimizerWrapper( self._optimizer, fp32_grad_accum=fp32_grad_accum, @@ -528,7 +524,6 @@ def configure_optimizers(self): async_grad_allreduce=async_grad_allreduce, grad_div_ar_fusion=grad_div_ar_fusion, grad_allreduce_chunk_size_mb=self.cfg.get('grad_allreduce_chunk_size_mb', 125), - context_parallel=context_parallel, ) assert self._trainer.max_steps is not None, "'max_steps' is missing in trainer config." 
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index f550f4d58103..ce6c76dff760 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -232,14 +232,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): wrap_with_ddp=False, on_cpu=True, virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), - context_parallel=(self.cfg.get('context_parallel_size', 1) > 1), ) else: self.model = build_model( model_provider_func=self.model_provider_func, wrap_with_ddp=False, virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), - context_parallel=(self.cfg.get('context_parallel_size', 1) > 1), ) # if we're not using interleaved, then self.model is a module. @@ -390,7 +388,6 @@ def model_provider_func(self, pre_process, post_process): megatron_legacy=self.cfg.get('megatron_legacy', False), seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None), rotary_base=self.cfg.get('rotary_base', 10000), - context_parallel=(self.cfg.get('context_parallel_size', 1) > 1), ) return model diff --git a/nemo/collections/nlp/modules/common/megatron/build_model.py b/nemo/collections/nlp/modules/common/megatron/build_model.py index bbcdeed9c089..2749eae846cd 100644 --- a/nemo/collections/nlp/modules/common/megatron/build_model.py +++ b/nemo/collections/nlp/modules/common/megatron/build_model.py @@ -48,7 +48,6 @@ def build_model( virtual_pipeline_model_parallel_size: Optional[int] = None, model_type: ModelType = ModelType.encoder_or_decoder, on_cpu: bool = False, - context_parallel: bool = False, *args: Any, **kwargs: Any, ) -> List[torch.nn.Module]: @@ -155,7 +154,7 @@ def build_model( model_module, device_ids=[i], output_device=i, - process_group=parallel_state.get_data_parallel_group(with_context_parallel=context_parallel), + process_group=parallel_state.get_data_parallel_group(with_context_parallel=True), ) for model_module in model ] diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py index 149c52af4923..522f143f6e34 100755 --- a/nemo/collections/nlp/modules/common/megatron/language_model.py +++ b/nemo/collections/nlp/modules/common/megatron/language_model.py @@ -127,7 +127,6 @@ def get_language_model( use_flash_attention=False, seq_len_interpolation_factor=None, rotary_base=10000, - context_parallel=False, ): """Build language model and return along with the key to save.""" @@ -205,7 +204,6 @@ def get_language_model( use_flash_attention=use_flash_attention, seq_len_interpolation_factor=seq_len_interpolation_factor, rotary_base=rotary_base, - context_parallel=context_parallel, ) # key used for checkpoints. 
language_model_key = 'language_model' @@ -507,7 +505,6 @@ def __init__( use_flash_attention=False, seq_len_interpolation_factor=None, rotary_base=10000, - context_parallel=False, ): super(TransformerLanguageModel, self).__init__( config=config, share_token_embeddings=share_embeddings_and_output_weights @@ -531,7 +528,7 @@ def __init__( self.share_embeddings_and_output_weights = share_embeddings_and_output_weights self.sequence_parallel = config.sequence_parallel self.dtype = utils_funcs.torch_dtype_from_precision(precision, megatron_amp_O2) - self.context_parallel = context_parallel + self.context_parallel = (parallel_state.get_context_parallel_world_size() > 1) if kv_channels is None: assert ( diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index e91b440ea3b3..f1a536dab491 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -1471,8 +1471,7 @@ def forward( # fp8_autocast will not do anything if TE or FP8 isn't used fp8_group = None if self.fp8 and parallel_state.model_parallel_is_initialized(): - cp_size = parallel_state.get_context_parallel_world_size() - fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=(cp_size > 1)) + fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True) if HAVE_TE: # if TE is installed but fp8 is not available then this will do nothing diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 2df1b74fbdda..2e253de9449e 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -156,10 +156,9 @@ def configure_ddp(self): # to False in PTL 2.0 and hence pre_configure_ddp() is removed in ddp.py # self.pre_configure_ddp() # device_ids = self.determine_ddp_device_ids() - context_parallel = app_state.context_parallel_size is not None and app_state.context_parallel_size > 1 self._model = DistributedDataParallel( _LightningModuleWrapperBase(self.model), - process_group=parallel_state.get_data_parallel_group(with_context_parallel=context_parallel), + process_group=parallel_state.get_data_parallel_group(with_context_parallel=True), **self._ddp_kwargs, ) diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index 7a319d21a0e7..a7baf67b9057 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -73,9 +73,7 @@ def __init__( # Initialize process groups if 'process_group' not in kwargs and not parallel_state.is_unitialized(): - context_parallel = kwargs['context_parallel'] if 'context_parallel' in kwargs else False - kwargs['process_group'] = parallel_state.get_data_parallel_group(with_context_parallel=context_parallel) - kwargs.pop('context_parallel', False) + kwargs['process_group'] = parallel_state.get_data_parallel_group(with_context_parallel=True) if disable_distributed_parameters: world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() diff --git a/nemo/core/optim/optimizer_with_main_params.py b/nemo/core/optim/optimizer_with_main_params.py index 682dd2877142..d6a8bf9f7044 100644 --- a/nemo/core/optim/optimizer_with_main_params.py +++ b/nemo/core/optim/optimizer_with_main_params.py @@ -105,10 +105,10 @@ def zero(self): """Reset the buffer to zero.""" self.data.zero_() - def allreduce_buffer(self, context_parallel=False): + def allreduce_buffer(self): 
"""Synchronous buffer data allreduce """ - self.data.div_(get_data_parallel_world_size(with_context_parallel=context_parallel)) - torch.distributed.all_reduce(self.data, group=get_data_parallel_group(with_context_parallel=context_parallel)) + self.data.div_(get_data_parallel_world_size(with_context_parallel=True)) + torch.distributed.all_reduce(self.data, group=get_data_parallel_group(with_context_parallel=True)) def get(self, shape, start_index): """Return a tensor with the input `shape` as a view into the @@ -176,7 +176,6 @@ def __init__( async_grad_allreduce=False, grad_div_ar_fusion=True, grad_allreduce_chunk_size_mb=0, - context_parallel=False, ): if not HAVE_APEX: raise ImportError( @@ -203,14 +202,13 @@ def __init__( self._fp32_grad_accum = fp32_grad_accum self._contiguous_grad_bucket = contiguous_grad_bucket - self._context_parallel = context_parallel # used with tensor parallel only (no pipeline parallelism) # be careful, weight update cannot start until all async grad AR works are done self._async_grad_allreduce = ( - async_grad_allreduce and get_data_parallel_world_size(with_context_parallel=self._context_parallel) > 1 + async_grad_allreduce and get_data_parallel_world_size(with_context_parallel=True) > 1 ) - self._grad_divisor = 1 / get_data_parallel_world_size(with_context_parallel=self._context_parallel) + self._grad_divisor = 1 / get_data_parallel_world_size(with_context_parallel=True) if self._async_grad_allreduce: # use @no_sync to disable backward grad sync during gradient accumulation @@ -344,34 +342,34 @@ def param_hook(*unused): if self._grad_div_ar_fusion: torch.distributed.all_reduce( allreduce_tensor, - group=get_data_parallel_group(with_context_parallel=self._context_parallel), + group=get_data_parallel_group(with_context_parallel=True), async_op=True, op=torch.distributed._make_nccl_premul_sum(self._grad_divisor), ) else: allreduce_tensor.div_( - get_data_parallel_world_size(with_context_parallel=self._context_parallel) + get_data_parallel_world_size(with_context_parallel=True) ) torch.distributed.all_reduce( allreduce_tensor, - group=get_data_parallel_group(with_context_parallel=self._context_parallel), + group=get_data_parallel_group(with_context_parallel=True), async_op=True, ) else: if self._grad_div_ar_fusion: torch.distributed.all_reduce( main_param.grad, - group=get_data_parallel_group(with_context_parallel=self._context_parallel), + group=get_data_parallel_group(with_context_parallel=True), async_op=True, op=torch.distributed._make_nccl_premul_sum(self._grad_divisor), ) else: main_param.grad.div_( - get_data_parallel_world_size(with_context_parallel=self._context_parallel) + get_data_parallel_world_size(with_context_parallel=True) ) torch.distributed.all_reduce( main_param.grad, - group=get_data_parallel_group(with_context_parallel=self._context_parallel), + group=get_data_parallel_group(with_context_parallel=True), async_op=True, ) @@ -482,7 +480,7 @@ def load_state_dict(self, state_dict): def allreduce_main_grads(self): for i in self._main_grad_buffers: - self._main_grad_buffers[i].allreduce_buffer(context_parallel=self._context_parallel) + self._main_grad_buffers[i].allreduce_buffer() @contextmanager def no_sync(self): From 840103e8cede7b28baed35066c4e128782bf456d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Oct 2023 03:17:08 +0000 Subject: [PATCH 42/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 
.../nlp/modules/common/megatron/language_model.py | 2 +- nemo/core/optim/optimizer_with_main_params.py | 12 +++--------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py index 522f143f6e34..0602e6d48db6 100755 --- a/nemo/collections/nlp/modules/common/megatron/language_model.py +++ b/nemo/collections/nlp/modules/common/megatron/language_model.py @@ -528,7 +528,7 @@ def __init__( self.share_embeddings_and_output_weights = share_embeddings_and_output_weights self.sequence_parallel = config.sequence_parallel self.dtype = utils_funcs.torch_dtype_from_precision(precision, megatron_amp_O2) - self.context_parallel = (parallel_state.get_context_parallel_world_size() > 1) + self.context_parallel = parallel_state.get_context_parallel_world_size() > 1 if kv_channels is None: assert ( diff --git a/nemo/core/optim/optimizer_with_main_params.py b/nemo/core/optim/optimizer_with_main_params.py index d6a8bf9f7044..a809d95a43b6 100644 --- a/nemo/core/optim/optimizer_with_main_params.py +++ b/nemo/core/optim/optimizer_with_main_params.py @@ -347,9 +347,7 @@ def param_hook(*unused): op=torch.distributed._make_nccl_premul_sum(self._grad_divisor), ) else: - allreduce_tensor.div_( - get_data_parallel_world_size(with_context_parallel=True) - ) + allreduce_tensor.div_(get_data_parallel_world_size(with_context_parallel=True)) torch.distributed.all_reduce( allreduce_tensor, group=get_data_parallel_group(with_context_parallel=True), @@ -364,13 +362,9 @@ def param_hook(*unused): op=torch.distributed._make_nccl_premul_sum(self._grad_divisor), ) else: - main_param.grad.div_( - get_data_parallel_world_size(with_context_parallel=True) - ) + main_param.grad.div_(get_data_parallel_world_size(with_context_parallel=True)) torch.distributed.all_reduce( - main_param.grad, - group=get_data_parallel_group(with_context_parallel=True), - async_op=True, + main_param.grad, group=get_data_parallel_group(with_context_parallel=True), async_op=True, ) return param_hook From 03b2922de3282addab499cbdb05b355b85f191cb Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Tue, 24 Oct 2023 20:23:49 -0700 Subject: [PATCH 43/47] code style fix Signed-off-by: Xiaowei Ren --- .../nlp/modules/common/megatron/language_model.py | 2 +- nemo/core/optim/optimizer_with_main_params.py | 12 +++--------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py index 522f143f6e34..0602e6d48db6 100755 --- a/nemo/collections/nlp/modules/common/megatron/language_model.py +++ b/nemo/collections/nlp/modules/common/megatron/language_model.py @@ -528,7 +528,7 @@ def __init__( self.share_embeddings_and_output_weights = share_embeddings_and_output_weights self.sequence_parallel = config.sequence_parallel self.dtype = utils_funcs.torch_dtype_from_precision(precision, megatron_amp_O2) - self.context_parallel = (parallel_state.get_context_parallel_world_size() > 1) + self.context_parallel = parallel_state.get_context_parallel_world_size() > 1 if kv_channels is None: assert ( diff --git a/nemo/core/optim/optimizer_with_main_params.py b/nemo/core/optim/optimizer_with_main_params.py index d6a8bf9f7044..a809d95a43b6 100644 --- a/nemo/core/optim/optimizer_with_main_params.py +++ b/nemo/core/optim/optimizer_with_main_params.py @@ -347,9 +347,7 @@ def param_hook(*unused): 
op=torch.distributed._make_nccl_premul_sum(self._grad_divisor), ) else: - allreduce_tensor.div_( - get_data_parallel_world_size(with_context_parallel=True) - ) + allreduce_tensor.div_(get_data_parallel_world_size(with_context_parallel=True)) torch.distributed.all_reduce( allreduce_tensor, group=get_data_parallel_group(with_context_parallel=True), @@ -364,13 +362,9 @@ def param_hook(*unused): op=torch.distributed._make_nccl_premul_sum(self._grad_divisor), ) else: - main_param.grad.div_( - get_data_parallel_world_size(with_context_parallel=True) - ) + main_param.grad.div_(get_data_parallel_world_size(with_context_parallel=True)) torch.distributed.all_reduce( - main_param.grad, - group=get_data_parallel_group(with_context_parallel=True), - async_op=True, + main_param.grad, group=get_data_parallel_group(with_context_parallel=True), async_op=True, ) return param_hook From 7c5b9c1c4007a6b1510084e4d883befa7d63a47f Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Tue, 24 Oct 2023 20:35:55 -0700 Subject: [PATCH 44/47] recover docker file Signed-off-by: Xiaowei Ren --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 96e64219df81..06f96a091a22 100644 --- a/Dockerfile +++ b/Dockerfile @@ -46,7 +46,7 @@ WORKDIR /workspace/ # install megatron core, this can be removed once 0.3 pip package is released RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \ cd Megatron-LM && \ - git checkout 954a65b04c01a4986adbad2a7cc9e9a2d094dd77 && \ + git checkout ab0336a5c8eab77aa74ae604ba1e73decbf6d560 && \ pip install -e . WORKDIR /tmp/ From 2da819e536a7a08663005bf6f93282a9e0cecf3d Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Mon, 27 Nov 2023 13:56:19 -0800 Subject: [PATCH 45/47] fix seq_length of CP Signed-off-by: Xiaowei Ren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index f7455f10edce..152b7b82a200 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -514,7 +514,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): model=self.model, num_microbatches=get_num_microbatches(), forward_only=forward_only, - seq_length=self.cfg.encoder_seq_length, + seq_length=(self.cfg.encoder_seq_length // self.cfg.get('context_parallel_size', 1)), micro_batch_size=self.cfg.micro_batch_size, ) From 22eeaf9d363794844c3cb35592258fb1065c72a3 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Mon, 4 Dec 2023 15:20:33 -0800 Subject: [PATCH 46/47] recover seq-length which has been fixed in mcore Signed-off-by: Xiaowei Ren --- .../nlp/models/language_modeling/megatron_gpt_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 87f507dc3cfb..b9c4b80ca2d1 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -515,7 +515,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): model=self.model, num_microbatches=get_num_microbatches(), forward_only=forward_only, - seq_length=(self.cfg.encoder_seq_length // self.cfg.get('context_parallel_size', 1)), + 
seq_length=self.cfg.encoder_seq_length, micro_batch_size=self.cfg.micro_batch_size, ) From 5d25e671636c384bd46cdc33cb9f24404c91527f Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Mon, 18 Dec 2023 17:28:55 -0800 Subject: [PATCH 47/47] function name fix Signed-off-by: Xiaowei Ren --- .../nlp/models/language_modeling/megatron_base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index f2a97b848223..38645a4a2a58 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -238,7 +238,7 @@ def setup_transformer_engine_cp_groups(self): """ cp_stream = torch.cuda.Stream() - for module in self.get_gpt_module_list(): + for module in self.get_model_module_list(): """Set context parallel running Copied from: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py """
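
Note on the recurring change in the patches above: call sites that previously passed with_context_parallel=context_parallel (or plumbed a context_parallel flag through constructors and the distributed optimizer) now pass with_context_parallel=True, so gradient reductions always run over the combined data-parallel and context-parallel group; when context_parallel_size == 1 this collapses to the plain data-parallel group, so behavior is unchanged for non-CP runs. Below is a minimal sketch of the resulting gradient divisor, assuming the usual decomposition of the world size into TP x PP x CP x DP; the helper name is illustrative and not part of these patches.

    # Illustration only: size of the group that gradients are averaged over when
    # with_context_parallel=True. Assumes world_size = TP * PP * CP * DP.
    def dp_cp_group_size(world_size: int, tp: int, pp: int, cp: int) -> int:
        assert world_size % (tp * pp * cp) == 0, "invalid parallel configuration"
        dp = world_size // (tp * pp * cp)
        # Gradients are summed over data-parallel and context-parallel ranks
        # together, so the divisor is dp * cp (equal to dp when cp == 1).
        return dp * cp

    # 16 GPUs with TP=2, PP=2, CP=2 -> DP=2; grads are divided by 4, not 2.
    print(dp_cp_group_size(16, 2, 2, 2))  # 4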
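Patches 45 and 46 adjust, then restore, the seq_length passed to the forward/backward step: with context parallelism each rank holds only encoder_seq_length // context_parallel_size tokens of every sequence, and the final state passes the full length because, per the commit message, the per-rank division is handled inside Megatron-Core. A hedged sketch of the per-rank shard size, with a hypothetical helper name:

    # Illustration only: tokens held by each context-parallel rank once a
    # sequence is sharded across the CP group. The helper name is hypothetical.
    def per_rank_seq_length(encoder_seq_length: int, context_parallel_size: int) -> int:
        assert encoder_seq_length % context_parallel_size == 0, (
            "sequence length must be divisible by context_parallel_size"
        )
        return encoder_seq_length // context_parallel_size

    # A 4096-token sequence with context_parallel_size=2 leaves 2048 tokens per rank.
    print(per_rank_seq_length(4096, 2))  # 2048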