From d50a12508989e67039920bfc063951552b95ec83 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Mon, 14 Jul 2025 19:08:12 -0700 Subject: [PATCH 01/81] Start heterogeneous CP idea prototyping --- megatron/core/pipeline_parallel/schedules.py | 30 ++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 317789ad684..733fb26e152 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -446,6 +446,30 @@ def check_first_val_step(first_val_step, forward_only, cond): else: return cond +def heterogeneous_context_parallel(single_forward_step, total_num_tokens): + def forward_func_wrapper(*args, **kwargs): + # calculate new loop count + # TODO: N, complete_cp_assignment = get_heterogeneous_cp_assignment(data_iterator) + N=4 + # TODO: calculate the right cp_assignment for micro-microbatch 0 + # current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, 0) + # TODO: data["cp_assignment"] = current_cp_assignment + # TODO: create new data_iterator with the cp_assignment + # Run the 1st micro-microbatch + output_tensor, num_tokens = single_forward_step(*args, **kwargs) + total_num_tokens += num_tokens + # Run the N-1 backward steps, N-1 forward steps. + # We will be left with Nth backward step after this loop which is run in the original function. + for i in range(1, N): + backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) + + # TODO: calculate the right cp_assignment for micro-microbatch i + # current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, i) + # TODO: data["cp_assignment"] = cp_assignment + # TODO: create new data_iterator with the cp_assignment + output_tensor, num_tokens = single_forward_step(*args, **kwargs) + total_num_tokens += num_tokens + return output_tensor, total_num_tokens def forward_backward_no_pipelining( *, @@ -495,6 +519,9 @@ def forward_backward_no_pipelining( forward_data_store = [] input_tensor, output_tensor_grad = None, None total_num_tokens = torch.zeros([], dtype=torch.int, device="cuda") + if config.heterogeneous_context_parallel: + forward_step = heterogeneous_context_parallel(forward_step, total_num_tokens) + with no_sync_func(): for i in range(num_microbatches - 1): output_tensor, num_tokens = forward_step( @@ -513,6 +540,9 @@ def forward_backward_no_pipelining( if not forward_only: backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) + # TODO: Does the last complete microbatch should run with sync? + # What will this do when we have multiple micro-microbatches within a forward and we ideally want only sync at end? + # Run computation for last microbatch out of context handler (want to # synchronize gradients). output_tensor, num_tokens = forward_step( From 676f0bf0bd0f6dcc884176481cb3d543b3ef253d Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Wed, 16 Jul 2025 12:50:46 -0700 Subject: [PATCH 02/81] Add current_cp_assignment and some supporting code --- megatron/core/pipeline_parallel/schedules.py | 42 +++++++++++++++----- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 733fb26e152..37619085f93 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -446,15 +446,40 @@ def check_first_val_step(first_val_step, forward_only, cond): else: return cond +def get_current_cp_assignment(complete_cp_assignment, microbatch_id, rank): + ''' + complete_cp_assignment is a list of lists, + Each inner list contains the cp_assignment (assigned GPU ranks) for a sub-sample + This function returns the ith sub-sample assigned to a GPU, None otherwise + For example, complete_cp_assignment = [[0, 1, 2, 3], [4, 5], [4, 5], [6, 7], [6, 7]] + For microbatch_id = 0; rank = 4; current_cp_assignment is [None, [4, 5], None, None, None] + This informs rank 4 that it should pick-up the 2nd sub-sample and share with rank 5 + For microbatch_id = 1; rank = 4; current_cp_assignment is [None, None, [4, 5], None, None] + This informs rank 4 that it should pick-up the 3rd sub-sample and share with rank 5 + ''' + current_cp_assignment = [None] * len(complete_cp_assignment) + matched_sample = -1 + index = None + for i, assigned_ranks in enumerate(complete_cp_assignment): + if rank in assigned_ranks: + matched_sample += 1 + if matched_sample == microbatch_id: + current_cp_assignment[i] = assigned_ranks + break + return current_cp_assignment + def heterogeneous_context_parallel(single_forward_step, total_num_tokens): def forward_func_wrapper(*args, **kwargs): + rank = torch.distributed.get_rank() # TODO: Get the correct rank based on process groups + original_data_iterator = args.data_iterator + data = next(original_data_iterator) # TODO: Protect for model parallelism # calculate new loop count - # TODO: N, complete_cp_assignment = get_heterogeneous_cp_assignment(data_iterator) + assert hasattr(data, "packed_seq_metadata"), "data must have a packed_seq_metadata attribute" + # TODO: N, complete_cp_assignment = get_heterogeneous_cp_assignment(data["packed_seq_metadata"]) N=4 - # TODO: calculate the right cp_assignment for micro-microbatch 0 - # current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, 0) - # TODO: data["cp_assignment"] = current_cp_assignment - # TODO: create new data_iterator with the cp_assignment + current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, 0) + data["cp_assignment"] = current_cp_assignment + args.data_iterator = RerunDataIterator(iter([data])) # Run the 1st micro-microbatch output_tensor, num_tokens = single_forward_step(*args, **kwargs) total_num_tokens += num_tokens @@ -463,10 +488,9 @@ def forward_func_wrapper(*args, **kwargs): for i in range(1, N): backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - # TODO: calculate the right cp_assignment for micro-microbatch i - # current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, i) - # TODO: data["cp_assignment"] = cp_assignment - # TODO: create new data_iterator with the cp_assignment + current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, i) + data["cp_assignment"] = current_cp_assignment + args.data_iterator = RerunDataIterator(iter([data])) output_tensor, num_tokens = single_forward_step(*args, **kwargs) total_num_tokens += num_tokens return output_tensor, total_num_tokens From 65a24cc03e017734bbbd07f07c6d7660357d93e4 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Wed, 16 Jul 2025 16:08:32 -0700 Subject: [PATCH 03/81] Add bricks for building heterogeneous CP schedules --- megatron/core/pipeline_parallel/schedules.py | 4 +- megatron/core/utils.py | 63 ++++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 37619085f93..d4f52320047 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -474,8 +474,8 @@ def forward_func_wrapper(*args, **kwargs): original_data_iterator = args.data_iterator data = next(original_data_iterator) # TODO: Protect for model parallelism # calculate new loop count - assert hasattr(data, "packed_seq_metadata"), "data must have a packed_seq_metadata attribute" - # TODO: N, complete_cp_assignment = get_heterogeneous_cp_assignment(data["packed_seq_metadata"]) + assert hasattr(data, "cu_seqlens"), "data must have a cu_seqlens attribute" + # TODO: N, complete_cp_assignment = get_heterogeneous_cp_assignment(data["cu_seqlens"], args.config.max_seqlen_per_cp_rank, args.config.context_parallel_size) N=4 current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, 0) data["cp_assignment"] = current_cp_assignment diff --git a/megatron/core/utils.py b/megatron/core/utils.py index bdbb2627a52..ad82669b570 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -5,6 +5,7 @@ import array import functools import hashlib +import heapq import inspect import logging import math @@ -1862,6 +1863,68 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): return batch +def get_total_workload(seq_length: int): + return seq_length*seq_length + +def get_heterogeneous_cp_assignment(cu_seqlens: List[int], max_seqlen_per_cp_rank: int, cp_size: int, flops_calculator: Optional[Callable] = get_total_workload): + """ + cu_seqlens: list of sub-sample sequence lengths + max_seqlen_per_cp_rank: list of max sequence length per CP rank + cp_size: total number of CP ranks + flops_calculator: function to calculate flops from cu_seqlens + + Returns: + start_time[j]: the time job j begins + assignment[j]: list of resource IDs assigned to job j + """ + cp_size_per_sample = [math.ceil(x / max_seqlen_per_cp_rank) for x in cu_seqlens] + total_workload_per_sample = [flops_calculator(x) for x in cu_seqlens] + total_workload_per_cp_rank = [work/cp_size for work, cp in zip(total_workload_per_sample, cp_size_per_sample)] + n = len(total_workload_per_cp_rank) + # jobs in descending p[j] + jobs = sorted(range(n), key=lambda j: total_workload_per_cp_rank[j], reverse=True) + current_time = 0 + + # a min-heap of free resource IDs + free_resources = list(range(cp_size)) + heapq.heapify(free_resources) + + # events: (release_time, [list of resource IDs freeing then]) + events = [] + + start_time = [None] * n + assignment = [None] * n + + while jobs: + made_progress = True + # try to schedule any job that fits in the currently free resources + while made_progress: + made_progress = False + for j in list(jobs): + if cp_size_per_sample[j] <= len(free_resources): + # grab the lowest‐ID resources available + assigned = [heapq.heappop(free_resources) for _ in range(cp_size_per_sample[j])] + start_time[j] = current_time + assignment[j] = assigned + # schedule their release + release_time = current_time + total_workload_per_cp_rank[j] + heapq.heappush(events, (release_time, assigned)) + jobs.remove(j) + made_progress = True + break + + # if nothing fits right now, advance to the next release event + if not made_progress and events: + t, freed_ids = heapq.heappop(events) + current_time = t + for rid in freed_ids: + heapq.heappush(free_resources, rid) + elif not events: + # should not happen when cp_size ≥ max(cp_size_per_sample) + break + + # TODO: Return per GPU count of sub-samples to be processed instead of start_time + return start_time, assignment ###################### ### NVTX profiling ### From 0c5fa3937a6f552a3a651b116e989c7e91fabceb Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 17 Jul 2025 12:53:08 -0700 Subject: [PATCH 04/81] Return sub-sample per CP rank count to schedules --- megatron/core/pipeline_parallel/schedules.py | 6 +-- megatron/core/utils.py | 40 ++++++++++---------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index d4f52320047..6b965bdefa6 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -468,7 +468,7 @@ def get_current_cp_assignment(complete_cp_assignment, microbatch_id, rank): break return current_cp_assignment -def heterogeneous_context_parallel(single_forward_step, total_num_tokens): +def heterogeneous_context_parallel(single_forward_step, total_num_tokens, input_tensor, output_tensor_grad, model_type): def forward_func_wrapper(*args, **kwargs): rank = torch.distributed.get_rank() # TODO: Get the correct rank based on process groups original_data_iterator = args.data_iterator @@ -486,7 +486,7 @@ def forward_func_wrapper(*args, **kwargs): # Run the N-1 backward steps, N-1 forward steps. # We will be left with Nth backward step after this loop which is run in the original function. for i in range(1, N): - backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) + backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, args.config) current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, i) data["cp_assignment"] = current_cp_assignment @@ -544,7 +544,7 @@ def forward_backward_no_pipelining( input_tensor, output_tensor_grad = None, None total_num_tokens = torch.zeros([], dtype=torch.int, device="cuda") if config.heterogeneous_context_parallel: - forward_step = heterogeneous_context_parallel(forward_step, total_num_tokens) + forward_step = heterogeneous_context_parallel(forward_step, total_num_tokens, input_tensor, output_tensor_grad, model_type) with no_sync_func(): for i in range(num_microbatches - 1): diff --git a/megatron/core/utils.py b/megatron/core/utils.py index ad82669b570..14c3e6a829f 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -1863,10 +1863,10 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): return batch -def get_total_workload(seq_length: int): - return seq_length*seq_length +def get_total_workload(seq_length: int, cp_size: int): + return (seq_length*seq_length)/cp_size -def get_heterogeneous_cp_assignment(cu_seqlens: List[int], max_seqlen_per_cp_rank: int, cp_size: int, flops_calculator: Optional[Callable] = get_total_workload): +def get_heterogeneous_cp_assignment(cu_seqlens: List[int], max_seqlen_per_cp_rank: int, cp_size: int, compute_estimator: Optional[Callable] = get_total_workload): """ cu_seqlens: list of sub-sample sequence lengths max_seqlen_per_cp_rank: list of max sequence length per CP rank @@ -1877,36 +1877,39 @@ def get_heterogeneous_cp_assignment(cu_seqlens: List[int], max_seqlen_per_cp_ran start_time[j]: the time job j begins assignment[j]: list of resource IDs assigned to job j """ + cp_rank = torch.distributed.get_rank() # Get rank from CP group cp_size_per_sample = [math.ceil(x / max_seqlen_per_cp_rank) for x in cu_seqlens] - total_workload_per_sample = [flops_calculator(x) for x in cu_seqlens] - total_workload_per_cp_rank = [work/cp_size for work, cp in zip(total_workload_per_sample, cp_size_per_sample)] - n = len(total_workload_per_cp_rank) - # jobs in descending p[j] - jobs = sorted(range(n), key=lambda j: total_workload_per_cp_rank[j], reverse=True) - current_time = 0 + total_workload_per_cp_rank = [compute_estimator(x, cp) for x, cp in zip(cu_seqlens, cp_size_per_sample)] + # Sort workloads in descending order + num_sub_samples = len(cu_seqlens) + jobs = sorted(range(num_sub_samples), key=lambda j: total_workload_per_cp_rank[j], reverse=True) - # a min-heap of free resource IDs + # a min-heap of free resource IDs (CP rank IDs) free_resources = list(range(cp_size)) heapq.heapify(free_resources) # events: (release_time, [list of resource IDs freeing then]) events = [] - - start_time = [None] * n - assignment = [None] * n - + # Trackers used in scheduling algorithm + current_time = 0 + start_time = [None] * num_sub_samples + assignment = [None] * num_sub_samples + num_sub_samples_processed = 0 + while jobs: made_progress = True - # try to schedule any job that fits in the currently free resources + # try to schedule any sub-sample that fits in the currently free resources while made_progress: made_progress = False for j in list(jobs): if cp_size_per_sample[j] <= len(free_resources): - # grab the lowest‐ID resources available + # grab the lowest‐ID CP ranks available assigned = [heapq.heappop(free_resources) for _ in range(cp_size_per_sample[j])] + if cp_rank in assigned: + num_sub_samples_processed += 1 start_time[j] = current_time assignment[j] = assigned - # schedule their release + # schedule the completion of the sub-sample compute release_time = current_time + total_workload_per_cp_rank[j] heapq.heappush(events, (release_time, assigned)) jobs.remove(j) @@ -1923,8 +1926,7 @@ def get_heterogeneous_cp_assignment(cu_seqlens: List[int], max_seqlen_per_cp_ran # should not happen when cp_size ≥ max(cp_size_per_sample) break - # TODO: Return per GPU count of sub-samples to be processed instead of start_time - return start_time, assignment + return num_sub_samples_processed, assignment ###################### ### NVTX profiling ### From 8546b0ff873169a45c8ab7cb22addfb21bcb6d0c Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 17 Jul 2025 17:54:05 -0700 Subject: [PATCH 05/81] Move wrapper to utils and formatting --- megatron/core/pipeline_parallel/schedules.py | 58 ++--------- megatron/core/utils.py | 103 ++++++++++++++++++- 2 files changed, 105 insertions(+), 56 deletions(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 6b965bdefa6..c4eb3891797 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -19,6 +19,7 @@ get_model_config, get_model_type, get_model_xattn, + heterogeneous_context_parallel, nvtx_range_pop, nvtx_range_push, ) @@ -446,54 +447,6 @@ def check_first_val_step(first_val_step, forward_only, cond): else: return cond -def get_current_cp_assignment(complete_cp_assignment, microbatch_id, rank): - ''' - complete_cp_assignment is a list of lists, - Each inner list contains the cp_assignment (assigned GPU ranks) for a sub-sample - This function returns the ith sub-sample assigned to a GPU, None otherwise - For example, complete_cp_assignment = [[0, 1, 2, 3], [4, 5], [4, 5], [6, 7], [6, 7]] - For microbatch_id = 0; rank = 4; current_cp_assignment is [None, [4, 5], None, None, None] - This informs rank 4 that it should pick-up the 2nd sub-sample and share with rank 5 - For microbatch_id = 1; rank = 4; current_cp_assignment is [None, None, [4, 5], None, None] - This informs rank 4 that it should pick-up the 3rd sub-sample and share with rank 5 - ''' - current_cp_assignment = [None] * len(complete_cp_assignment) - matched_sample = -1 - index = None - for i, assigned_ranks in enumerate(complete_cp_assignment): - if rank in assigned_ranks: - matched_sample += 1 - if matched_sample == microbatch_id: - current_cp_assignment[i] = assigned_ranks - break - return current_cp_assignment - -def heterogeneous_context_parallel(single_forward_step, total_num_tokens, input_tensor, output_tensor_grad, model_type): - def forward_func_wrapper(*args, **kwargs): - rank = torch.distributed.get_rank() # TODO: Get the correct rank based on process groups - original_data_iterator = args.data_iterator - data = next(original_data_iterator) # TODO: Protect for model parallelism - # calculate new loop count - assert hasattr(data, "cu_seqlens"), "data must have a cu_seqlens attribute" - # TODO: N, complete_cp_assignment = get_heterogeneous_cp_assignment(data["cu_seqlens"], args.config.max_seqlen_per_cp_rank, args.config.context_parallel_size) - N=4 - current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, 0) - data["cp_assignment"] = current_cp_assignment - args.data_iterator = RerunDataIterator(iter([data])) - # Run the 1st micro-microbatch - output_tensor, num_tokens = single_forward_step(*args, **kwargs) - total_num_tokens += num_tokens - # Run the N-1 backward steps, N-1 forward steps. - # We will be left with Nth backward step after this loop which is run in the original function. - for i in range(1, N): - backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, args.config) - - current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, i) - data["cp_assignment"] = current_cp_assignment - args.data_iterator = RerunDataIterator(iter([data])) - output_tensor, num_tokens = single_forward_step(*args, **kwargs) - total_num_tokens += num_tokens - return output_tensor, total_num_tokens def forward_backward_no_pipelining( *, @@ -544,7 +497,9 @@ def forward_backward_no_pipelining( input_tensor, output_tensor_grad = None, None total_num_tokens = torch.zeros([], dtype=torch.int, device="cuda") if config.heterogeneous_context_parallel: - forward_step = heterogeneous_context_parallel(forward_step, total_num_tokens, input_tensor, output_tensor_grad, model_type) + forward_step = heterogeneous_context_parallel( + forward_step, total_num_tokens, input_tensor, output_tensor_grad, model_type + ) with no_sync_func(): for i in range(num_microbatches - 1): @@ -565,8 +520,9 @@ def forward_backward_no_pipelining( backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) # TODO: Does the last complete microbatch should run with sync? - # What will this do when we have multiple micro-microbatches within a forward and we ideally want only sync at end? - + # What will this do when we have multiple micro-microbatches within a + # forward and we ideally want only sync at end? + # Run computation for last microbatch out of context handler (want to # synchronize gradients). output_tensor, num_tokens = forward_step( diff --git a/megatron/core/utils.py b/megatron/core/utils.py index 14c3e6a829f..1e7f25dc038 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -1863,10 +1863,28 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): return batch + def get_total_workload(seq_length: int, cp_size: int): - return (seq_length*seq_length)/cp_size + """ + seq_length: sequence length of a sub-sample + cp_size: total number of CP ranks working on this sub-sample + + Note: + This function is used to estimate the relative workload intensity + of a sub-sample. This is not meant to be an accurate flops calculator. + + Returns: + workload: workload of a sub-sample + """ + return (seq_length * seq_length) / cp_size + -def get_heterogeneous_cp_assignment(cu_seqlens: List[int], max_seqlen_per_cp_rank: int, cp_size: int, compute_estimator: Optional[Callable] = get_total_workload): +def get_heterogeneous_cp_assignment( + cu_seqlens: List[int], + max_seqlen_per_cp_rank: int, + cp_size: int, + compute_estimator: Optional[Callable] = get_total_workload, +): """ cu_seqlens: list of sub-sample sequence lengths max_seqlen_per_cp_rank: list of max sequence length per CP rank @@ -1877,9 +1895,11 @@ def get_heterogeneous_cp_assignment(cu_seqlens: List[int], max_seqlen_per_cp_ran start_time[j]: the time job j begins assignment[j]: list of resource IDs assigned to job j """ - cp_rank = torch.distributed.get_rank() # Get rank from CP group + cp_rank = torch.distributed.get_rank() # Get rank from CP group cp_size_per_sample = [math.ceil(x / max_seqlen_per_cp_rank) for x in cu_seqlens] - total_workload_per_cp_rank = [compute_estimator(x, cp) for x, cp in zip(cu_seqlens, cp_size_per_sample)] + total_workload_per_cp_rank = [ + compute_estimator(x, cp) for x, cp in zip(cu_seqlens, cp_size_per_sample) + ] # Sort workloads in descending order num_sub_samples = len(cu_seqlens) jobs = sorted(range(num_sub_samples), key=lambda j: total_workload_per_cp_rank[j], reverse=True) @@ -1895,7 +1915,7 @@ def get_heterogeneous_cp_assignment(cu_seqlens: List[int], max_seqlen_per_cp_ran start_time = [None] * num_sub_samples assignment = [None] * num_sub_samples num_sub_samples_processed = 0 - + while jobs: made_progress = True # try to schedule any sub-sample that fits in the currently free resources @@ -1928,6 +1948,79 @@ def get_heterogeneous_cp_assignment(cu_seqlens: List[int], max_seqlen_per_cp_ran return num_sub_samples_processed, assignment + +def get_current_cp_assignment(complete_cp_assignment, microbatch_id, rank): + ''' + complete_cp_assignment is a list of lists, + Each inner list contains the cp_assignment (assigned GPU ranks) for a sub-sample + This function returns the ith sub-sample assigned to a GPU, None otherwise + For example, complete_cp_assignment = [[0, 1, 2, 3], [4, 5], [4, 5], [6, 7], [6, 7]] + For microbatch_id = 0; rank = 4; current_cp_assignment is [None, [4, 5], None, None, None] + This informs rank 4 that it should pick-up the 2nd sub-sample and share with rank 5 + For microbatch_id = 1; rank = 4; current_cp_assignment is [None, None, [4, 5], None, None] + This informs rank 4 that it should pick-up the 3rd sub-sample and share with rank 5 + ''' + current_cp_assignment = [None] * len(complete_cp_assignment) + matched_sample = -1 + index = None + for i, assigned_ranks in enumerate(complete_cp_assignment): + if rank in assigned_ranks: + matched_sample += 1 + if matched_sample == microbatch_id: + current_cp_assignment[i] = assigned_ranks + break + return current_cp_assignment + + +def heterogeneous_context_parallel( + single_forward_step, total_num_tokens, input_tensor, output_tensor_grad, model_type +): + """ + Heterogeneous context parallel is a technique to balance the workload + of each CP rank when we use packed samples with variable sequence lengths. + This provides a wrapper function that replaces the 1 forward step of the + original microbatch with N forward + N-1 backward steps where N is the number + of sub-samples assigned to this CP rank. + """ + + def forward_func_wrapper(*args, **kwargs): + rank = parallel_state.get_context_parallel_rank() + original_data_iterator = args.data_iterator + data = next(original_data_iterator) # TODO: Protect for model parallelism + # calculate new loop count + assert hasattr(data, "cu_seqlens"), ( + "data must have a cu_seqlens attribute to define the valid sequenece lengths " + "of each sub-sample in a packed sample to use heterogeneous context parallel" + ) + # num_subsamples: number of sub-samples assigned to this CP rank + # complete_cp_assignment: list of lists, inner list CP ranks assigned to a sub-sample + num_subsamples, complete_cp_assignment = get_heterogeneous_cp_assignment( + data["cu_seqlens"], + args.config.max_seqlen_per_cp_rank, + args.config.context_parallel_size, + ) + # current_cp_assignment: list of lists, each inner list contains the + # CP ranks assigned to the sub-samples that are executing in the forward-backward loop. + # See function get_current_cp_assignment for more details. + current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, 0, rank) + data["cp_assignment"] = current_cp_assignment + args.data_iterator = RerunDataIterator(iter([data])) + # Run the 1st micro-microbatch + output_tensor, num_tokens = single_forward_step(*args, **kwargs) + total_num_tokens += num_tokens + # Run the N-1 backward steps, N-1 forward steps. + # Will be left with Nth backward step after this loop which is run in the original function. + for i in range(1, num_subsamples): + backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, args.config) + + current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, i) + data["cp_assignment"] = current_cp_assignment + args.data_iterator = RerunDataIterator(iter([data])) + output_tensor, num_tokens = single_forward_step(*args, **kwargs) + total_num_tokens += num_tokens + return output_tensor, total_num_tokens + + ###################### ### NVTX profiling ### ###################### From b17a2a21cd2a91be4d210853e14e8db7abd5c1c3 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Fri, 18 Jul 2025 12:15:27 -0700 Subject: [PATCH 06/81] Add TransformerConfig and arguments --- megatron/core/model_parallel_config.py | 13 +++++++++++++ megatron/training/arguments.py | 7 +++++++ 2 files changed, 20 insertions(+) diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index fc0041a0a70..3b501e076dc 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -52,6 +52,19 @@ class ModelParallelConfig: type. """ + #TODO: Should we calculate this using total seq length / cp size automatically? + max_seqlen_per_cp_rank: Optional[int] = None + """ + Maximum sequence length per CP rank. This is used to calculate the number of sub-samples + assigned to each CP rank when using heterogeneous context parallel. + """ + + heterogeneous_context_parallel: bool = False + """ + If true, enables heterogeneous context parallel. This is used to balance the workload of + each CP rank when we use packed samples with variable sequence lengths. + """ + expert_model_parallel_size: int = 1 """Distributes Moe Experts across sub data parallel dimension.""" diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 35dc99b2858..4a4b8c56053 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2406,6 +2406,13 @@ def _add_distributed_args(parser): '--hierarchical-context-parallel-sizes 2 4 indicates every two adjacent gpus ' 'forms the first level of cp groups and the cp ranks with the same odevity ' 'forms the second level of cp groups.') + group.add_argument('--max-seqlen-per-cp-rank', type=int, default=None, + help='Maximum sequence length per CP rank. This is used to calculate the ' + 'number of sub-samples assigned to each CP rank when using heterogeneous context parallel.') + group.add_argument('--heterogeneous-context-parallel', action='store_true', default=False, + help='Enables heterogeneous context parallel. This is used to balance the workload ' + 'of each CP rank when we use packed samples with variable sequence lengths. Requires ' + '--max-seqlen-per-cp-rank to be set.') group.add_argument('--nccl-communicator-config-path', type=str, default=None, help='Path to the yaml file with NCCL communicator ' 'configurations. The number of min/max thread groups and thread ' From 98f592ad2cf0f3b438f662276c8cf15d1aa71661 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Fri, 18 Jul 2025 18:38:03 -0700 Subject: [PATCH 07/81] Fix implementation issues [WIP] --- megatron/core/pipeline_parallel/schedules.py | 14 ++++++--- megatron/core/utils.py | 33 ++++++++++++-------- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index c4eb3891797..03fe1f8e85e 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -496,14 +496,18 @@ def forward_backward_no_pipelining( forward_data_store = [] input_tensor, output_tensor_grad = None, None total_num_tokens = torch.zeros([], dtype=torch.int, device="cuda") - if config.heterogeneous_context_parallel: - forward_step = heterogeneous_context_parallel( - forward_step, total_num_tokens, input_tensor, output_tensor_grad, model_type - ) + # Create a wrapper function that handles heterogeneous context parallel if needed + def forward_step_wrapper(*args, **kwargs): + if config.heterogeneous_context_parallel: + return heterogeneous_context_parallel( + forward_step, backward_step, total_num_tokens, input_tensor, output_tensor_grad, model_type + )(*args, **kwargs) + else: + return forward_step(*args, **kwargs) with no_sync_func(): for i in range(num_microbatches - 1): - output_tensor, num_tokens = forward_step( + output_tensor, num_tokens = forward_step_wrapper( forward_step_func, data_iterator, model, diff --git a/megatron/core/utils.py b/megatron/core/utils.py index 1e7f25dc038..41aed9eae20 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -1895,13 +1895,14 @@ def get_heterogeneous_cp_assignment( start_time[j]: the time job j begins assignment[j]: list of resource IDs assigned to job j """ + sub_sample_lens = cu_seqlens[0][1:] - cu_seqlens[0][:-1] cp_rank = torch.distributed.get_rank() # Get rank from CP group - cp_size_per_sample = [math.ceil(x / max_seqlen_per_cp_rank) for x in cu_seqlens] + cp_size_per_sample = [math.ceil(x / max_seqlen_per_cp_rank) for x in sub_sample_lens] total_workload_per_cp_rank = [ - compute_estimator(x, cp) for x, cp in zip(cu_seqlens, cp_size_per_sample) + compute_estimator(x, cp) for x, cp in zip(sub_sample_lens, cp_size_per_sample) ] # Sort workloads in descending order - num_sub_samples = len(cu_seqlens) + num_sub_samples = len(sub_sample_lens) jobs = sorted(range(num_sub_samples), key=lambda j: total_workload_per_cp_rank[j], reverse=True) # a min-heap of free resource IDs (CP rank IDs) @@ -1973,7 +1974,7 @@ def get_current_cp_assignment(complete_cp_assignment, microbatch_id, rank): def heterogeneous_context_parallel( - single_forward_step, total_num_tokens, input_tensor, output_tensor_grad, model_type + single_forward_step, backward_step, total_num_tokens, input_tensor, output_tensor_grad, model_type ): """ Heterogeneous context parallel is a technique to balance the workload @@ -1984,9 +1985,14 @@ def heterogeneous_context_parallel( """ def forward_func_wrapper(*args, **kwargs): + nonlocal total_num_tokens rank = parallel_state.get_context_parallel_rank() - original_data_iterator = args.data_iterator + forward_signature = inspect.signature(single_forward_step) + bound_args = forward_signature.bind(*args, **kwargs) + bound_args.apply_defaults() + original_data_iterator = bound_args.arguments['data_iterator'] data = next(original_data_iterator) # TODO: Protect for model parallelism + config = bound_args.arguments['config'] # calculate new loop count assert hasattr(data, "cu_seqlens"), ( "data must have a cu_seqlens attribute to define the valid sequenece lengths " @@ -1996,29 +2002,30 @@ def forward_func_wrapper(*args, **kwargs): # complete_cp_assignment: list of lists, inner list CP ranks assigned to a sub-sample num_subsamples, complete_cp_assignment = get_heterogeneous_cp_assignment( data["cu_seqlens"], - args.config.max_seqlen_per_cp_rank, - args.config.context_parallel_size, + config.max_seqlen_per_cp_rank, + config.context_parallel_size, ) # current_cp_assignment: list of lists, each inner list contains the # CP ranks assigned to the sub-samples that are executing in the forward-backward loop. # See function get_current_cp_assignment for more details. current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, 0, rank) data["cp_assignment"] = current_cp_assignment - args.data_iterator = RerunDataIterator(iter([data])) + bound_args.arguments['data_iterator'] = RerunDataIterator(iter([data])) # Run the 1st micro-microbatch - output_tensor, num_tokens = single_forward_step(*args, **kwargs) + output_tensor, num_tokens = single_forward_step(*bound_args.args, **bound_args.kwargs) total_num_tokens += num_tokens # Run the N-1 backward steps, N-1 forward steps. # Will be left with Nth backward step after this loop which is run in the original function. for i in range(1, num_subsamples): - backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, args.config) + backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, i) + current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, i, rank) data["cp_assignment"] = current_cp_assignment - args.data_iterator = RerunDataIterator(iter([data])) - output_tensor, num_tokens = single_forward_step(*args, **kwargs) + bound_args.arguments['data_iterator'] = RerunDataIterator(iter([data])) + output_tensor, num_tokens = single_forward_step(*bound_args.args, **bound_args.kwargs) total_num_tokens += num_tokens return output_tensor, total_num_tokens + return forward_func_wrapper ###################### From 552bd278acee384de441d1e948f393f6aa0c180e Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Mon, 21 Jul 2025 15:32:10 -0700 Subject: [PATCH 08/81] Add sharding logic and packed_seq_params attribute for TE support --- megatron/core/packed_seq_params.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/megatron/core/packed_seq_params.py b/megatron/core/packed_seq_params.py index dff0cc59925..d7835038c65 100644 --- a/megatron/core/packed_seq_params.py +++ b/megatron/core/packed_seq_params.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from torch import Tensor +from typing import List @dataclass @@ -18,3 +19,4 @@ class PackedSeqParams: cu_seqlens_kv_padded: Tensor = None max_seqlen_q: Tensor = None max_seqlen_kv: Tensor = None + cp_assignment: List[int] = None From 48913d00e375118494f55b5873ef75f851db7474 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Mon, 21 Jul 2025 15:33:00 -0700 Subject: [PATCH 09/81] Protect attention shape change logic with THD assumption --- megatron/core/transformer/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 47ec1ff0626..8934ef615c8 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -662,7 +662,7 @@ def forward( ) ) - if packed_seq_params is not None: + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': query = query.squeeze(1) key = key.squeeze(1) value = value.squeeze(1) From 056055c22e8c290fd346a6bd08c35e5bbc56a73a Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Tue, 22 Jul 2025 16:59:00 -0700 Subject: [PATCH 10/81] Add sharding logic and TODO for sync across multiple microbatches --- megatron/core/utils.py | 47 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/megatron/core/utils.py b/megatron/core/utils.py index 41aed9eae20..cffc4beb29f 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -1831,7 +1831,7 @@ def is_submodule(module, parent_module, strict=True): ######################## -def get_batch_on_this_cp_rank(batch: Dict[str, Any]): +def get_batch_on_this_cp_rank(batch: Dict[str, Any], cp_size: Optional[int] = None, cp_rank: Optional[int] = None): """Slice batch input along sequence dimension into multiple chunks, which are parallelized across GPUs in a context parallel group. """ @@ -1842,12 +1842,14 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): # we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0 # and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so # that we can get balanced workload among GPUs in a context parallel group. - cp_size = parallel_state.get_context_parallel_world_size() - if cp_size > 1: + if cp_size is None: + cp_size = parallel_state.get_context_parallel_world_size() + if cp_rank is None: cp_rank = parallel_state.get_context_parallel_rank() + if cp_size > 1: for key, val in batch.items(): if val is not None: - seq_dim = 1 if key != "attention_mask" else 2 + seq_dim = 1 if key != 'attention_mask' else 2 val = val.view( *val.shape[0:seq_dim], 2 * cp_size, @@ -2000,6 +2002,16 @@ def forward_func_wrapper(*args, **kwargs): ) # num_subsamples: number of sub-samples assigned to this CP rank # complete_cp_assignment: list of lists, inner list CP ranks assigned to a sub-sample + # TODO: When some ranks finish with their assigned sub-samples, they get the next microbatch from the data iterator. + # They then call heterogeneous_cp_assignment again. But some ranks are still executing their previous sub-samples. + # The scheduling algorithm assumes that all ranks have empty queues. + # Need to preserve the state of the queue. It should only empty at the end of global batch. + # OR should I force sync at the end of each microbatch across CP ranks? + # I think forcing sync at the end of each microbatch across CP ranks is the easier option. + # Since we only see 1 microbatch at a time, a later microbatch can have more GPUs for a sample than the 1st microbatch groups. + # For example, if microbatch 0 had group of 4 GPUs running for a sub-sample but microbatch 1 requires 6 GPUs + # If GPU 4,5 are done before 0,1,2,3 in microbatch 0, then they will be waiting for 0,1,2,3 to catch up which can lead to deadlock in comms. + # Unless we can stop 4,5 from executing which means partial syncs for each executing sub-sample. num_subsamples, complete_cp_assignment = get_heterogeneous_cp_assignment( data["cu_seqlens"], config.max_seqlen_per_cp_rank, @@ -2010,6 +2022,8 @@ def forward_func_wrapper(*args, **kwargs): # See function get_current_cp_assignment for more details. current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, 0, rank) data["cp_assignment"] = current_cp_assignment + # TODO: Make this a tensor so that tensor parallel broadcast works + # data["cp_assignment"] = torch.tensor(current_cp_assignment, dtype=torch.int32) bound_args.arguments['data_iterator'] = RerunDataIterator(iter([data])) # Run the 1st micro-microbatch output_tensor, num_tokens = single_forward_step(*bound_args.args, **bound_args.kwargs) @@ -2027,6 +2041,31 @@ def forward_func_wrapper(*args, **kwargs): return output_tensor, total_num_tokens return forward_func_wrapper +def get_sub_sample_on_this_cp_rank(batch, current_cp_assignment, packed_seq_params, cp_rank): + cu_lengths = packed_seq_params.cu_seqlens_q_padded + for i in range(len(current_cp_assignment)): + if current_cp_assignment[i] is not None: + assert cp_rank in current_cp_assignment[i], f"Current cp rank {cp_rank} is not part of the cp_assignment {current_cp_assignment[i]} given to this GPU" + start_index = cu_lengths[i] + end_index = cu_lengths[i+1] + cp_shard_ranks = current_cp_assignment[i] + break + for key, data in batch.items(): + batch[key] = data[:, start_index:end_index] + + # TODO: Clean this up. Reduce code by calculating indices once + sub_sample_packed_seq_params = PackedSeqParams( + qkv_format="sbhd", + cu_seqlens_q=torch.tensor([0, packed_seq_params.cu_seqlens_q[i+1] - packed_seq_params.cu_seqlens_q[i]], device="cpu", pin_memory=True), + cu_seqlens_kv=torch.tensor([0, packed_seq_params.cu_seqlens_kv[i+1] - packed_seq_params.cu_seqlens_kv[i]], device="cpu", pin_memory=True), + cu_seqlens_q_padded=torch.tensor([0, packed_seq_params.cu_seqlens_q_padded[i+1] - packed_seq_params.cu_seqlens_q_padded[i]], device="cpu", pin_memory=True), + cu_seqlens_kv_padded=torch.tensor([0, packed_seq_params.cu_seqlens_kv_padded[i+1] - packed_seq_params.cu_seqlens_kv_padded[i]], device="cpu", pin_memory=True), + max_seqlen_q=torch.tensor([end_index - start_index], device="cpu", pin_memory=True), + max_seqlen_kv=torch.tensor([end_index - start_index], device="cpu", pin_memory=True), + cp_assignment=cp_shard_ranks, + ) + return batch, cp_shard_ranks, sub_sample_packed_seq_params + ###################### ### NVTX profiling ### From 8d343696b2c35f43b2ddc0c63bfb67a84fce2ee4 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Wed, 13 Aug 2025 18:38:27 -0700 Subject: [PATCH 11/81] Updates to working milestone 1 Hybrid CP scheduling --- .../core/extensions/transformer_engine.py | 12 + megatron/core/model_parallel_config.py | 9 +- .../common/embeddings/rotary_pos_embedding.py | 34 +- megatron/core/models/gpt/gpt_model.py | 3 +- megatron/core/packed_seq_params.py | 4 + megatron/core/parallel_state.py | 36 + .../pipeline_parallel/hybrid_cp_schedule.py | 695 ++++++++++++++++++ megatron/core/pipeline_parallel/schedules.py | 24 +- megatron/core/transformer/attention.py | 2 +- megatron/core/utils.py | 225 +----- megatron/training/arguments.py | 8 +- megatron/training/initialize.py | 1 + 12 files changed, 839 insertions(+), 214 deletions(-) create mode 100644 megatron/core/pipeline_parallel/hybrid_cp_schedule.py diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 8646313a62f..81002e357f0 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -884,6 +884,18 @@ def forward( packed_seq_params: PackedSeqParams = None, ): """Forward.""" + if packed_seq_params is not None: + if packed_seq_params.cp_group is not None: + self.cp_group = packed_seq_params.cp_group + super().set_context_parallel_group( + self.cp_group, + torch.distributed.get_process_group_ranks(self.cp_group), + super().cp_stream, + ) + elif packed_seq_params.local_cp_size is not None: + super().set_context_parallel_group(None, None, None) + self.kept_packed_seq_params.discard("cp_group") + self.kept_packed_seq_params.discard("local_cp_size") packed_seq_kwargs = ( {key: getattr(packed_seq_params, key) for key in self.kept_packed_seq_params} if packed_seq_params is not None diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 7636bcb5872..cc47e1876ae 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -52,16 +52,15 @@ class ModelParallelConfig: type. """ - #TODO: Should we calculate this using total seq length / cp size automatically? max_seqlen_per_cp_rank: Optional[int] = None """ - Maximum sequence length per CP rank. This is used to calculate the number of sub-samples - assigned to each CP rank when using heterogeneous context parallel. + Maximum sequence length per CP rank. This is used to calculate the number and length + of sub-samples assigned to each CP rank when using hybrid data x context parallel. """ - heterogeneous_context_parallel: bool = False + hybrid_context_parallel: bool = False """ - If true, enables heterogeneous context parallel. This is used to balance the workload of + If true, enables hybrid context parallel. This is used to balance the workload of each CP rank when we use packed samples with variable sequence lengths. """ diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index 0d7d5e626d0..72b55caf9ff 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -148,13 +148,12 @@ def get_cos_sin(self, max_seq_len: int, offset: int = 0) -> (Tensor, Tensor): return cos, sin @lru_cache(maxsize=32) - def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor: - """Forward pass of RoPE embedding. + def get_emb(self, max_seq_len: int, offset: int = 0) -> Tensor: + """Forward pass of RoPE embedding before CP sharding. Args: max_seq_len (int): Maximum size of sequence offset (int, optional): RoPE offset. Defaults to 0. - packed_seq (bool, optional): Whether to use packed sequence. Defaults to False. Returns: Tensor: Embeddings after applying RoPE. @@ -174,10 +173,33 @@ def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) - ) # emb [seq_length, .., dim] emb = emb[:, None, None, :] + return emb + + def forward(self, max_seq_len: int, offset: int = 0, packed_seq_params: Optional[PackedSeqParams] = None) -> Tensor: + """Forward pass of RoPE embedding. + + Args: + max_seq_len (int): Maximum size of sequence + offset (int, optional): RoPE offset. Defaults to 0. + packed_seq_params (PackedSeqParams, optional): Packed sequence params. Defaults to None. + + Returns: + Tensor: Embeddings after applying RoPE. + """ + emb = self.get_emb(max_seq_len, offset) + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' if self.cp_group is not None and self.cp_group.size() > 1 and not packed_seq: - # slice rotary_pos_emb along sequence dimension and select the parition of the current - # CP rank - emb = get_pos_emb_on_this_cp_rank(emb, 0, self.cp_group) + if packed_seq_params.local_cp_size is None: + cp_group = self.cp_group + elif packed_seq_params.local_cp_size > 1: + cp_group = packed_seq_params.cp_group + else: + cp_group = None + if cp_group is not None: + # slice rotary_pos_emb along sequence dimension and select the parition of the current + # CP rank + emb = get_pos_emb_on_this_cp_rank(emb, 0, cp_group) + return emb def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 62b22cb5a69..d34969b8a40 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -298,8 +298,7 @@ def _preprocess( ) rotary_pos_emb = self.rotary_pos_emb( rotary_seq_len, - packed_seq=packed_seq_params is not None - and packed_seq_params.qkv_format == 'thd', + packed_seq_params=packed_seq_params, ) elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention: if self.training or not self.config.flash_decode: diff --git a/megatron/core/packed_seq_params.py b/megatron/core/packed_seq_params.py index 65ad442ae3e..b93cc970ffb 100644 --- a/megatron/core/packed_seq_params.py +++ b/megatron/core/packed_seq_params.py @@ -3,6 +3,7 @@ from torch import Tensor from typing import List +import torch.distributed as dist @dataclass @@ -20,3 +21,6 @@ class PackedSeqParams: max_seqlen_q: int = None max_seqlen_kv: int = None cp_assignment: List[int] = None + local_cp_size: int = None + schedule_id: int = None + cp_group: dist.ProcessGroup = None diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 46400c390e5..526a873662d 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -105,6 +105,8 @@ _CONTEXT_PARALLEL_GLOBAL_RANKS = None # Hierarchical context parallel groups _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS = None +# Hybrid context parallel groups +_HYBRID_CP_GROUPS = {} # Data parallel group information with context parallel combined. _DATA_PARALLEL_GROUP_WITH_CP = None @@ -363,6 +365,22 @@ def create_hierarchical_groups( assert rank not in ranks or len(hierarchical_groups_gloo) == len(hierarchical_group_sizes) return hierarchical_groups, hierarchical_groups_gloo +def create_hybrid_cp_groups(rank, ranks, pg_options): + hybrid_cp_groups = {} + # Generate group for every power of 2 up to the number of CP ranks + # We limit the allowed group sizes in order to avoid excessive overhead. + group_sizes = [2**i for i in range(int(log2(len(ranks))))][1:] + for group_size in group_sizes: + for i in range(0, len(ranks), group_size): + group = create_group( + ranks[i:i+group_size], + pg_options=pg_options, + group_desc=f"HYBRID_CP_GROUP_{group_size}", + ) + if rank in ranks[i:i+group_size]: + assert group_size not in hybrid_cp_groups, f"Rank {rank} appears in multiple Hybrid CP groups of size {group_size}" + hybrid_cp_groups[group_size] = group + return hybrid_cp_groups class RankGenerator(object): """A class for generating rank groups for different modes of parallelism.""" @@ -473,6 +491,7 @@ def initialize_model_parallel( use_sharp: bool = False, context_parallel_size: int = 1, hierarchical_context_parallel_sizes: Optional[List[int]] = None, + hybrid_context_parallel: bool = False, expert_model_parallel_size: int = 1, num_distributed_optimizer_instances: int = 1, expert_tensor_parallel_size: Optional[int] = None, @@ -846,6 +865,13 @@ def initialize_model_parallel( if rank in ranks: _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS = hierarchical_groups + if hybrid_context_parallel: + assert len(ranks) % 2 == 0, "Hybrid context parallel requires an even number of ranks" + global _HYBRID_CP_GROUPS + if rank in ranks: + _HYBRID_CP_GROUPS.update(create_hybrid_cp_groups(rank, ranks, get_nccl_options("cp", nccl_comm_cfgs))) + #TODO: Are gloo groups needed for hybrid cp? + # Build the model-parallel groups. global _MODEL_PARALLEL_GROUP global _MODEL_PARALLEL_GLOBAL_RANKS @@ -1295,6 +1321,16 @@ def get_hierarchical_context_parallel_groups(check_initialized=True): assert _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS is not None return _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS +def get_hybrid_context_parallel_groups(check_initialized=True, group_size=None): + """Get the hybrid context parallel groups the caller rank belongs to.""" + # If the group size is the same as the entire DPxCP group, return the original group + if get_context_parallel_world_size() == group_size: + if check_initialized: + assert _CONTEXT_PARALLEL_GROUP is not None + return _CONTEXT_PARALLEL_GROUP + if check_initialized: + assert _HYBRID_CP_GROUPS is not None + return _HYBRID_CP_GROUPS[group_size] def get_embedding_group(check_initialized=True): """Get the embedding group the caller rank belongs to.""" diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py new file mode 100644 index 00000000000..8ee6ec02b32 --- /dev/null +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -0,0 +1,695 @@ +import torch +from megatron.core import parallel_state +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.rerun_state_machine import RerunDataIterator +from typing import Any, Callable, Dict, Deque, List, Optional, Tuple, Type, Union +from functools import lru_cache +import math +import heapq + +class BalancedCPScheduler: + def __init__(self, max_seq_len_per_rank: int): + self.max_seq_len_per_rank = max_seq_len_per_rank + self.num_subsamples = 0 + self.num_subsamples_processed = 0 + self.free_resources = [] + + @lru_cache(maxsize=128) + def get_total_workload(self, seq_length: int, cp_size: Optional[int] = None): + """ + seq_length: sequence length of a sub-sample + cp_size: total number of CP ranks working on this sub-sample + + Note: + This function is used to estimate the relative workload intensity + of a sub-sample. This is not meant to be an accurate flops calculator. + + Returns: + workload: workload of a sub-sample + """ + if cp_size is None: + cp_size = self.gpus_needed(seq_length) + return (seq_length * seq_length) / cp_size + + def get_heterogeneous_cp_assignment( + self, + cu_seqlens: List[int], + max_seqlen_per_cp_rank: int, + cp_size: int, + compute_estimator: Optional[Callable] = None, + ): + """ + cu_seqlens: list of sub-sample sequence lengths + max_seqlen_per_cp_rank: list of max sequence length per CP rank + cp_size: total number of CP ranks + flops_calculator: function to calculate flops from cu_seqlens + + Returns: + start_time[j]: the time job j begins + assignment[j]: list of resource IDs assigned to job j + """ + if compute_estimator is None: + compute_estimator = self.get_total_workload + #TODO: Account for nvlink - IBlink boundaries. How to dynamically get this information? + #TODO: is cu_seqlens a list of ints or tensors? Correct the type hint + sub_sample_lens = cu_seqlens[0][1:] - cu_seqlens[0][:-1] + cp_rank = torch.distributed.get_rank() # Get rank from CP group + cp_size_per_sample = [math.ceil(x / max_seqlen_per_cp_rank) for x in sub_sample_lens] + total_workload_per_cp_rank = [ + compute_estimator(x, cp) for x, cp in zip(sub_sample_lens, cp_size_per_sample) + ] + # Sort workloads in descending order + num_sub_samples = len(sub_sample_lens) + jobs = sorted(range(num_sub_samples), key=lambda j: total_workload_per_cp_rank[j], reverse=True) + + # a min-heap of free resource IDs (CP rank IDs) + if len(self.free_resources) == 0: + self.free_resources = list(range(cp_size)) + heapq.heapify(self.free_resources) + + # events: (release_time, [list of resource IDs freeing then]) + events = [] + # Trackers used in scheduling algorithm + current_time = 0 + start_time = [None] * num_sub_samples + assignment = [None] * num_sub_samples + num_sub_samples_processed = 0 + + while jobs: + made_progress = True + # try to schedule any sub-sample that fits in the currently free resources + while made_progress: + made_progress = False + for j in list(jobs): + if cp_size_per_sample[j] <= len(self.free_resources): + # grab the lowest‐ID CP ranks available + assigned = [heapq.heappop(self.free_resources) for _ in range(cp_size_per_sample[j])] + if cp_rank in assigned: + num_sub_samples_processed += 1 + start_time[j] = current_time + assignment[j] = assigned + # schedule the completion of the sub-sample compute + release_time = current_time + total_workload_per_cp_rank[j] + heapq.heappush(events, (release_time, assigned)) + jobs.remove(j) + made_progress = True + break + + # if nothing fits right now, advance to the next release event + if not made_progress and events: + t, freed_ids = heapq.heappop(events) + current_time = t + for rid in freed_ids: + heapq.heappush(self.free_resources, rid) + elif not events: + # should not happen when cp_size ≥ max(cp_size_per_sample) + break + + return num_sub_samples_processed, assignment + + # def get_per_microbatch_assignment( + # self, + # cu_seqlens: List[int], + # cp_size: int, + # compute_estimator: Optional[Callable] = None, + # ): + # """ + # cu_seqlens: list of sub-sample sequence lengths + # cp_size: total number of CP ranks + # compute_estimator: function to calculate flops from cu_seqlens + # """ + # micro_batches = [] + # exec_times = [] + # if compute_estimator is None: + # compute_estimator = self.get_total_workload + # sub_sample_lens = cu_seqlens[0][1:] - cu_seqlens[0][:-1] + # while sub_sample_lens: + # sub_sample_lens = sorted(sub_sample_lens, reverse=True) + # microbatch, sub_sample_lens, exec_times = self.next_hdp_group(sub_sample_lens, compute_estimator, cp_size) + # micro_batches.append(microbatch) + # exec_times.append(exec_times) + # return micro_batches + + @lru_cache(maxsize=128) + def gpus_needed(self, seq_len: int) -> int: + return max(1, 2 ** ceil(log2((seq_len / self.max_seq_len_per_rank)))) + + def make_buckets_equal( + self, + sample_seqlens: List[Tuple[int, int]], # List of (sample_id, sequence_length) tuples + compute_estimator: Callable[[int], float], + ) -> List[deque]: + """ + Modified version of make_buckets_equal_work that works with (sample_id, seq_len) tuples. + This keeps sample IDs tethered to their sequence lengths throughout the bucketing process. + """ + # Extract just the sequence lengths for determining k + seqlens = [seq_len for _, seq_len in sample_seqlens] + + # Determine k based on unique GPU categories needed + k = len({self.gpus_needed(L) for L in seqlens}) + + # Use the existing contiguous_equal_buckets function but with sample_seqlens + # We need to modify it to work with tuples + work = [] + for _, s in sample_seqlens: + cp_size = self.gpus_needed(s) + work.append(compute_estimator(s, cp_size)) + total_work = sum(work) + target = total_work / k + buckets, cur, cur_work = [], [], 0.0 + remaining_work = total_work + remaining_k = k + + for i, (sample_id, seq_len) in enumerate(sample_seqlens): + work = compute_estimator(seq_len) + projected = cur_work + work + + # Check if we should close this bucket + if (cur and + (projected > target * 1.1 or # Too much work + len(sample_seqlens) - i <= remaining_k - len(buckets))): # Need to save sequences for remaining buckets + buckets.append(deque(cur)) + cur, cur_work = [], 0.0 + remaining_work -= sum(compute_estimator(seq_len) for _, seq_len in cur) + remaining_k -= 1 + + cur.append((sample_id, seq_len)) + cur_work += work + + if cur: + buckets.append(deque(cur)) + + return buckets + + def next_hdp_group( + self, + sample_seqlens: List[Tuple[int, int]], # List of (sample_id, sequence_length) tuples + compute_estimator: Callable[[int], float], + total_gpus: int, + delta: float = 0.05, # balance slack (e.g. 5 %) + strategy: str = "dp", # "dp" or "pp" + eps_bucket: float = 0.10, # ε target for bucket balance + ) -> Tuple[List[List[int]], List[Tuple[int, int]], List[float], List[List[int]]]: + """ + Given a list of (sample_id, sequence_length) tuples, this function aims to assign + sequences to a microbatch such that all GPUs in the CP domain have a roughly balanced workload. + Once each microbatch is roughly balanced, we exit and return the microbatch and the leftover sequences. + + The function performs the following passes in order to form a balanced microbatch: + 1. We create buckets of sequences that are roughly balanced. + We try to create as many buckets as possible CP sizes. + 2. Given a bucket has sequences available, we assign the microbatch + a. To a new set of GPUs if there are enough free GPUs. + b. To an existing set of GPUs with the lowest load. + 3. We check if the microbatch is balanced whenever we need to move onto a new CP size in the same set of GPUs. + 4. We trim the microbatch if removing the last added sequence helps improve balance. + 5. If we run out of sequences to assign and there are empty GPUs, + we redistribute work to empty GPUs by recursively increasing the CP size of a sample until no empty GPUs are left.. + + #TODO: Add clarification on when we check for balance. What does prev_needed do? + + Returns (*micro_batches*, *leftover_sample_seqlens*, *exec_times*, *sample_ids_per_gpu*). + """ + if not sample_seqlens: + return [[] for _ in range(total_gpus)], [], [0.0 for _ in range(total_gpus)], [[] for _ in range(total_gpus)] + + # Use the improved bucketing that works with (sample_id, seq_len) tuples + buckets = self.make_buckets_equal(sample_seqlens, compute_estimator) + + # Initialize tracking structures + micro_batches = [[] for _ in range(total_gpus)] + exec_times = [0.0 for _ in range(total_gpus)] + sample_ids_per_gpu = [[] for _ in range(total_gpus)] + + gpu_group_id = [None] * total_gpus + group_members = {} + group_size = {} + next_gid = 0 + + pp_cursor = 0 + prev_needed = None + check_balance = False + + while buckets: + # ---- Step 1 – pick the next sequence we COULD place ------------------ + sample_seq_tuple = bucket_idx = None + needed = None + + scan_order = ( + range(len(buckets)) if strategy == "dp" + else [(pp_cursor + i) % len(buckets) for i in range(len(buckets))] + ) + + for idx in scan_order: + if not buckets[idx]: + continue + cand_tuple = buckets[idx][0] # This is now (sample_id, seq_len) + cand_seq_len = cand_tuple[1] + needed = self.gpus_needed(cand_seq_len) + + # (a) Do we have an *existing* group of size `needed`? + candidate_gids = [gid for gid, sz in group_size.items() if sz == needed] + + # (b) Or enough completely free GPUs to start a new group? + free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None] + if candidate_gids or len(free_ranks) >= needed: + sample_seq_tuple, bucket_idx = cand_tuple, idx + break + + # No place to put any remaining sequence – finish this micro‑batch + if sample_seq_tuple is None: + break + + if strategy == "pp": + pp_cursor = (bucket_idx + 1) % len(buckets) + + sample_id, seq_len = sample_seq_tuple + needed = self.gpus_needed(seq_len) + if prev_needed is None: + prev_needed = needed + + # (a) Existing groups of exactly this size + candidate_gids = [ + gid for gid, sz in group_size.items() if sz == needed + ] + if candidate_gids: + best_gid, best_load = min( + ((gid, max(exec_times[r] for r in group_members[gid])) + for gid in candidate_gids), + key=lambda t: t[1] + ) + else: + best_gid, best_load = None, float("inf") + + # (b) Hypothetical **new** group from completely free GPUs + free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None] + if len(free_ranks) >= needed: + free_sorted = sorted(free_ranks, key=lambda r: exec_times[r]) + new_members = free_sorted[:needed] + new_load = exec_times[new_members[-1]] + + if new_load < best_load: + best_gid = None + chosen_members = new_members + else: + chosen_members = group_members[best_gid] + else: + if best_gid is None: + print(f"No room to form a new group") + chosen_members = group_members[best_gid] + + # ---- Step 2b – if we decided to create a fresh group ---------------- + if best_gid is None: + best_gid = next_gid + next_gid += 1 + group_members[best_gid] = chosen_members + group_size[best_gid] = needed + for r in chosen_members: + gpu_group_id[r] = best_gid + + # ---- Step 3 – assign the sequence to every member of that group ------ + per_gpu_cost = compute_estimator(seq_len) + + for r in chosen_members: + micro_batches[r].append(seq_len) + exec_times[r] += per_gpu_cost + sample_ids_per_gpu[r].append(sample_id) + + # Remove the sequence definitively from its bucket + buckets[bucket_idx].popleft() + + # ---- Step 4 – tidy, balance‑check, maybe early‑exit ------------------ + while buckets and not buckets[0]: + buckets.pop(0) + pp_cursor %= max(1, len(buckets)) + + # TODO: Should I pre-emptively break out if slack is already within delta? + # Feels like if we have global batch level samples, we will have lots with same CP size. + # So we can just keep adding samples. + # We already have trim workload to handle imbalanced cases. + # TODO: Removing this helps reduce the number of groups when we have lots of samples with same CP size. + # But because we don't exit as soon as we get balanced, even if there is one group available that can take the next sample, + # we will keep adding samples to the same group. + # trim_overload() does not help because it only checks if removing the last added sample helps. + # We cannot check after adding every sample because there will always be imbalance if we don't wait for future scheduling. + + # IMPORTANT: So we need a solution here + if needed < prev_needed: + # When we get into a lower CP size in the same group, we can start checking for balance. + # There is still a gotcha here. + # Let's say we have a group of 3 GPU 0-2, then we move onto group of 2. + # We keep assigning group of 2 as we do in descending order but GPU 7/15 never sees a microbatch assigned to it + # until we run out of samples with CP2. + # This means we are never balanced as min(exec_times) will always be 0. + # We need a smart way of identifying that we have run out of big samples and if we are having to + # assign work to a GPU already working, is it because there are empty GPUs? + # Would assigning work to empty GPUs first by moving onto next CP bucket help? + # But we need to remember to come back to this CP size bucket and then check for balance. + # Maybe the scheduling algorithm should look at empty GPUs and find work rather than going + # sequence by sequence. + check_balance = True + + if check_balance and buckets and max(exec_times) - min(exec_times) <= delta * max(exec_times): + break + + # Gather leftovers (flatten remaining buckets, preserve order) + leftovers = [] + for b in buckets: + for sample_seq_tuple in b: + leftovers.append(sample_seq_tuple) + + # --------------------------------------------------------------------------- + def trim_overload(): + """ + Iteratively pop the most‑recent sequence from the *most‑loaded group* + whenever doing so reduces the global slack. + """ + while True: + cur_max = max(exec_times) + cur_min = min(exec_times) + cur_slack = cur_max - cur_min + if cur_slack <= delta * cur_max: + break + + max_r = exec_times.index(cur_max) + gid = gpu_group_id[max_r] + members = group_members[gid] + + if not micro_batches[max_r] or len(micro_batches[max_r]) <= 1: + break + + seq = micro_batches[max_r][-1] + need = group_size[gid] + per_gpu_cost = compute_estimator(seq) + + proj_times = exec_times[:] + for r in members: + proj_times[r] -= per_gpu_cost + + proj_slack = max(proj_times) - min(proj_times) + + if proj_slack < cur_slack: + sample_id_to_remove = sample_ids_per_gpu[max_r][-1] + for r in members: + micro_batches[r].pop() + exec_times[r] -= per_gpu_cost + sample_ids_per_gpu[r].pop() + leftovers.append((sample_id_to_remove, seq)) + else: + break + + trim_overload() + + # Track work before redistribution + total_work_before = sum(len(mb) for mb in micro_batches) + + # Check for empty GPUs and redistribute work + def fill_empty_gpus(micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size): + """ + Recursively check for empty GPUs and redistribute work by increasing + the number of GPUs sharing samples. This ensures all GPUs have work. + GPUs must be allocated consecutively. + """ + # Find empty GPUs + empty_gpus = [i for i in range(total_gpus) if not micro_batches[i]] + if not empty_gpus: + return micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size # No empty GPUs, we're done + + # Find the smallest group size that exists + existing_group_sizes = set(group_size.values()) + if not existing_group_sizes: + return # No groups exist, cannot redistribute + + min_group_size = min(existing_group_sizes) + next_power = min_group_size * 2 + + # Find the first group of min_group_size that can be expanded + expandable_gid = None + expandable_members = None + expandable_new_gpus = None + + for gid, size in group_size.items(): + if size == min_group_size: + members = group_members[gid] + # get_new_work_queue(members[-1], min_group_size) + needed_count = min_group_size + current_gpu = members[-1] + empty_gpu = [idx for idx, work in enumerate(micro_batches) if not work][0] + assert not all(work for work in micro_batches[empty_gpu:empty_gpu+needed_count]), f"Not enough empty GPUs to expand or there are empty GPUs between work scheduled which is not allowed." + work_to_push = micro_batches[current_gpu + 1 : empty_gpu] # This is work of all other subsequent sub-samples + exec_times_to_push = exec_times[current_gpu + 1 : empty_gpu] + sample_ids_to_push = sample_ids_per_gpu[current_gpu + 1 : empty_gpu] + + + new_micro_batches = [[]] * len(micro_batches) + new_exec_times = [0.0] * len(exec_times) + new_sample_ids_per_gpu = [[]] * len(sample_ids_per_gpu) + + for i in range(current_gpu+1): + new_micro_batches[i] = micro_batches[i] + new_exec_times[i] = exec_times[i] + new_sample_ids_per_gpu[i] = sample_ids_per_gpu[i] + + for i in range(needed_count): + new_micro_batches[current_gpu + 1 +i] = micro_batches[current_gpu] + new_exec_times[current_gpu + 1 + i] = exec_times[current_gpu] + new_sample_ids_per_gpu[current_gpu + 1 + i] = sample_ids_per_gpu[current_gpu] + + for i, work in enumerate(work_to_push): + new_micro_batches[current_gpu + needed_count + 1 + i] = work + new_exec_times[current_gpu + needed_count + 1 + i] = exec_times_to_push[i] + new_sample_ids_per_gpu[current_gpu + needed_count + 1 + i] = sample_ids_to_push[i] + + group_size[gid] = next_power + group_members[gid] = list(range(members[0], members[-1] + needed_count + 1)) + for pushed_gid in group_size.keys(): + if pushed_gid > gid: + group_members[pushed_gid] = [x + needed_count for x in group_members[pushed_gid]] + + return new_micro_batches, new_exec_times, new_sample_ids_per_gpu, group_members, group_size + + + empty_gpus = any([not micro_batches[i] for i in range(total_gpus)]) + while empty_gpus: + micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size = fill_empty_gpus(micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size) + empty_gpus = any([not micro_batches[i] for i in range(total_gpus)]) + + # Assert that no work has been completely removed + total_work_after = sum(len(mb) for mb in micro_batches) + assert total_work_after >= total_work_before, f"Work was removed: {total_work_before} -> {total_work_after}" + + return micro_batches, leftovers, exec_times, sample_ids_per_gpu + + def get_groups_and_subsamples( + self, + data, + model, + config, + ): + # TODO: Protect for model parallelism + # TODO: Reduce access to file system as much as possible. + groups = [] + sample_id_groups = [] + assert "cu_lengths" in data, ( + "data must have a cu_seqlens attribute to define the valid sequenece lengths " + "of each sub-sample in a packed sample to use hybrid context parallel" + ) + # We assign a sample_id to each sub-sample in order to track the right assignment to each GPU. + # TODO (Milestone 2): Sample ID logic will have to change once we have global batch + sample_id_seqlens = [(i, int(data["cu_lengths"][0][i+1] - data["cu_lengths"][0][i])) for i in range(0, data["cu_lengths"][0].shape[0] - 1)] + sample_id_seqlens = sorted(sample_id_seqlens, key=lambda x: x[1], reverse=True) + while sample_id_seqlens: + mb, sample_id_seqlens, exec_times, sample_ids = self.next_hdp_group(sample_id_seqlens, self.get_total_workload, config.context_parallel_size) + groups.append(mb) + if len(sample_ids) < config.context_parallel_size: + sample_ids.extend([] * (config.context_parallel_size - len(sample_ids))) + sample_id_groups.append(sample_ids) + # print(f"groups: {groups}") + # print(f"sample_id_groups: {sample_id_groups}") + + return groups, sample_id_groups + +def hybrid_context_parallel_forward_backward( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + output_tensor_grad, + forward_data_store, + config, + collect_non_loss_data, + first_val_step, + forward_only, + no_sync_func, + total_num_tokens, + check_first_val_step, + model_type, +): + """ + Scheduler for Hybrid Context Parallel. + + This function performs the packed sample scheduling and determines + 1. The number of microbatches to schedule for each CP rank + 2. The number of groups each CP rank should execute + 3. The number of sub-samples per group each CP rank should execute + + A group is defined by a set of samples that can run across the CP domain without any barrier. + There are many reasons why we may not be able to run endless number of samples within a single group. + For example, if we have 8 GPUs, + if GPU 0-5 are assigned a long sample that requires CP6, + GPU 6-7 are assigned a short sample that requires CP2, + The next sample which requires CP4 can be assigned GPU 4-7. + But GPU 6-7 will finish first and get deadlocked if GPU 4-5 are not participating in the group. + + As of now, the number of microbatches is pre-determined by GBS and DP size. + We perform the scheduling for each microbatch. + In the future, when we schedule over the entire global batch, we will remove the need for step #2 and + number of microbatches will be determined by the number of groups. + """ + from .schedules import forward_step, backward_step + + cp_balancing_scheduler = BalancedCPScheduler(max_seq_len_per_rank=config.max_seqlen_per_cp_rank) + with no_sync_func(): + for i in range(num_microbatches - 1): + data = next(data_iterator) + groups, sample_id_groups = cp_balancing_scheduler.get_groups_and_subsamples(data, model, config) + for j in range(len(groups)): + # Get sub-samples for the current CP rank + # TODO: Update to DPxCP rank when milestone 2 + sample_ids_per_group = sample_id_groups[j][parallel_state.get_context_parallel_rank()] + for k in range(len(sample_ids_per_group)): + # Call forward step for each sub-sample + sub_sample_id = sample_ids_per_group[k] + partner_cp_size = len([True for sample_ids in sample_id_groups[j] if sub_sample_id in sample_ids]) + if partner_cp_size == 0: + assert False, f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} j: {j} k: {k} sample_ids_group: {sample_id_groups}" + data["local_cp_size"] = partner_cp_size + data["scheduled_id"] = sub_sample_id + new_data_iterator = RerunDataIterator(iter([data])) + # TODO: Change data iterator to the right sub-sample + # TODO: Find the usage of current_microbatch and is_first_microbatch and how that may affect my usage. + output_tensor, num_tokens = forward_step( + forward_step_func, + new_data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0), + current_microbatch=i, + ) + total_num_tokens += num_tokens.item() + if not forward_only: + backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) + + # TODO: Move to DPxCP barrier + torch.distributed.barrier(parallel_state.get_context_parallel_group()) + + # Last microbatch + # TODO: Call scheduler here. + with no_sync_func(): + data = next(data_iterator) + groups, sample_id_groups = cp_balancing_scheduler.get_groups_and_subsamples(data, model, config) + for j in range(len(groups) - 1): + sample_ids_per_group = sample_id_groups[j][parallel_state.get_context_parallel_rank()] + for k in range(len(sample_ids_per_group)): + # Call forward step for each sub-sample + sub_sample_id = sample_ids_per_group[k] + + partner_cp_size = len([True for sample_ids in sample_id_groups[j] if sub_sample_id in sample_ids]) + if partner_cp_size == 0: + assert False, f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} j: {j} k: {k} sample_ids_group: {sample_id_groups}" + data["local_cp_size"] = partner_cp_size + data["scheduled_id"] = sub_sample_id + # TODO: What else should I update in data so that we can get the right sub-sample? + new_data_iterator = RerunDataIterator(iter([data])) + # TODO: Change data iterator to the right sub-sample + # TODO: Find the usage of current_microbatch and is_first_microbatch and how that may affect my usage. + output_tensor, num_tokens = forward_step( + forward_step_func, + new_data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + is_first_microbatch=check_first_val_step(first_val_step, forward_only, num_microbatches == 1), + current_microbatch=num_microbatches - 1, + ) + total_num_tokens += num_tokens.item() + if not forward_only: + backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) + + # TODO: Move to DPxCP barrier + torch.distributed.barrier(parallel_state.get_context_parallel_group()) + + # For the last group, we need to run the last sub-sample out of the context handler. + # TODO: Find num sub-samples per group in this group + with no_sync_func(): + sample_ids_per_group = sample_id_groups[-1][parallel_state.get_context_parallel_rank()] + for k in range(len(sample_ids_per_group) - 1): + sub_sample_id = sample_ids_per_group[k] + partner_cp_size = len([True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids]) + data["local_cp_size"] = partner_cp_size + data["scheduled_id"] = sub_sample_id + # TODO: What else should I update in data so that we can get the right sub-sample? + new_data_iterator = RerunDataIterator(iter([data])) + # TODO: Change data iterator to the right sub-sample + # TODO: Find the usage of current_microbatch and is_first_microbatch and how that may affect my usage. + # Call forward step for each sub-sample + output_tensor, num_tokens = forward_step( + forward_step_func, + new_data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + is_first_microbatch=check_first_val_step(first_val_step, forward_only, num_microbatches == 1), + current_microbatch=num_microbatches - 1, + ) + total_num_tokens += num_tokens.item() + if not forward_only: + backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) + + # The last sub-sample of the last group of the last microbatch is run out of the context handler. + sub_sample_id = sample_ids_per_group[-1] + partner_cp_size = len([True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids]) + if partner_cp_size == 0: + assert False, f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} j: {j} k: {k} sample_ids_group: {sample_id_groups}" + data["local_cp_size"] = partner_cp_size + data["scheduled_id"] = sub_sample_id + # TODO: What else should I update in data so that we can get the right sub-sample? + new_data_iterator = RerunDataIterator(iter([data])) + # TODO: Change data iterator to the right sub-sample + # TODO: Find the usage of current_microbatch and is_first_microbatch and how that may affect my usage. + # Call forward step for each sub-sample + output_tensor, num_tokens = forward_step( + forward_step_func, + new_data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + is_first_microbatch=check_first_val_step(first_val_step, forward_only, num_microbatches == 1), + current_microbatch=num_microbatches - 1, + ) + total_num_tokens += num_tokens.item() + if not forward_only: + backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) + + # TODO: Move to DPxCP barrier + torch.distributed.barrier(parallel_state.get_context_parallel_group()) + + # TODO: Before returning forward_data_store, we need to change the loss. + # Instead of letting reporting loss be calculated by train_step, can we just calculate it here? + # Since we will need the global number of samples information to average the loss and local cp group size for each sample. + return forward_data_store, total_num_tokens \ No newline at end of file diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 123841b0ccc..46eeaa30ef7 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -18,16 +18,12 @@ get_attr_wrapped_model, get_model_config, get_model_type, -<<<<<<< HEAD - get_model_xattn, - heterogeneous_context_parallel, -======= ->>>>>>> 410222b3ec5e05f53871e7371e02128eff38c7e8 nvtx_range_pop, nvtx_range_push, ) from .combined_1f1b import combined_1f1b_schedule_for_no_pipelining +from .hybrid_cp_schedule import hybrid_context_parallel_forward_backward # Types Shape = Union[List[int], torch.Size] @@ -522,6 +518,24 @@ def forward_backward_no_pipelining( total_num_tokens, partial(check_first_val_step, first_val_step, forward_only), ) + elif config.hybrid_context_parallel: + forward_data_store, total_num_tokens = hybrid_context_parallel_forward_backward( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + output_tensor_grad, + forward_data_store, + config, + collect_non_loss_data, + first_val_step, + forward_only, + no_sync_func, + total_num_tokens, + check_first_val_step, + model_type, + ) else: with no_sync_func(): for i in range(num_microbatches - 1): diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 58b428c4a7d..e355fe9ef5b 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -674,7 +674,7 @@ def forward( if rotary_pos_emb is not None and not self.config.flash_decode: q_pos_emb, k_pos_emb = rotary_pos_emb - if packed_seq_params is not None: + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': if packed_seq_params.cu_seqlens_q_padded is not None: cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded else: diff --git a/megatron/core/utils.py b/megatron/core/utils.py index 0ed2f001e73..89dfdfdad16 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -1842,6 +1842,9 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any], cp_size: Optional[int] = No # we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0 # and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so # that we can get balanced workload among GPUs in a context parallel group. + if cp_size is not None or cp_rank is not None: + assert cp_size is not None and cp_rank is not None, "Both cp_size and cp_rank must be provided for batch slicing" + if cp_size is None: cp_size = parallel_state.get_context_parallel_world_size() if cp_rank is None: @@ -1865,206 +1868,46 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any], cp_size: Optional[int] = No return batch - -def get_total_workload(seq_length: int, cp_size: int): - """ - seq_length: sequence length of a sub-sample - cp_size: total number of CP ranks working on this sub-sample - - Note: - This function is used to estimate the relative workload intensity - of a sub-sample. This is not meant to be an accurate flops calculator. - - Returns: - workload: workload of a sub-sample - """ - return (seq_length * seq_length) / cp_size - - -def get_heterogeneous_cp_assignment( - cu_seqlens: List[int], - max_seqlen_per_cp_rank: int, - cp_size: int, - compute_estimator: Optional[Callable] = get_total_workload, -): - """ - cu_seqlens: list of sub-sample sequence lengths - max_seqlen_per_cp_rank: list of max sequence length per CP rank - cp_size: total number of CP ranks - flops_calculator: function to calculate flops from cu_seqlens - - Returns: - start_time[j]: the time job j begins - assignment[j]: list of resource IDs assigned to job j +def get_sub_sample_on_this_cp_rank(batch, scheduled_id, local_cp_size, packed_seq_params): """ - sub_sample_lens = cu_seqlens[0][1:] - cu_seqlens[0][:-1] - cp_rank = torch.distributed.get_rank() # Get rank from CP group - cp_size_per_sample = [math.ceil(x / max_seqlen_per_cp_rank) for x in sub_sample_lens] - total_workload_per_cp_rank = [ - compute_estimator(x, cp) for x, cp in zip(sub_sample_lens, cp_size_per_sample) - ] - # Sort workloads in descending order - num_sub_samples = len(sub_sample_lens) - jobs = sorted(range(num_sub_samples), key=lambda j: total_workload_per_cp_rank[j], reverse=True) - - # a min-heap of free resource IDs (CP rank IDs) - free_resources = list(range(cp_size)) - heapq.heapify(free_resources) - - # events: (release_time, [list of resource IDs freeing then]) - events = [] - # Trackers used in scheduling algorithm - current_time = 0 - start_time = [None] * num_sub_samples - assignment = [None] * num_sub_samples - num_sub_samples_processed = 0 - - while jobs: - made_progress = True - # try to schedule any sub-sample that fits in the currently free resources - while made_progress: - made_progress = False - for j in list(jobs): - if cp_size_per_sample[j] <= len(free_resources): - # grab the lowest‐ID CP ranks available - assigned = [heapq.heappop(free_resources) for _ in range(cp_size_per_sample[j])] - if cp_rank in assigned: - num_sub_samples_processed += 1 - start_time[j] = current_time - assignment[j] = assigned - # schedule the completion of the sub-sample compute - release_time = current_time + total_workload_per_cp_rank[j] - heapq.heappush(events, (release_time, assigned)) - jobs.remove(j) - made_progress = True - break - - # if nothing fits right now, advance to the next release event - if not made_progress and events: - t, freed_ids = heapq.heappop(events) - current_time = t - for rid in freed_ids: - heapq.heappush(free_resources, rid) - elif not events: - # should not happen when cp_size ≥ max(cp_size_per_sample) - break + For a packed sequence, this function returns + 1. The sub-sample of the sequence assigned to this CP rank. + 2. The appropriate CP group for the new CP assignment. + 3. The updated packed sequence parameters. - return num_sub_samples_processed, assignment - - -def get_current_cp_assignment(complete_cp_assignment, microbatch_id, rank): - ''' - complete_cp_assignment is a list of lists, - Each inner list contains the cp_assignment (assigned GPU ranks) for a sub-sample - This function returns the ith sub-sample assigned to a GPU, None otherwise - For example, complete_cp_assignment = [[0, 1, 2, 3], [4, 5], [4, 5], [6, 7], [6, 7]] - For microbatch_id = 0; rank = 4; current_cp_assignment is [None, [4, 5], None, None, None] - This informs rank 4 that it should pick-up the 2nd sub-sample and share with rank 5 - For microbatch_id = 1; rank = 4; current_cp_assignment is [None, None, [4, 5], None, None] - This informs rank 4 that it should pick-up the 3rd sub-sample and share with rank 5 - ''' - current_cp_assignment = [None] * len(complete_cp_assignment) - matched_sample = -1 - index = None - for i, assigned_ranks in enumerate(complete_cp_assignment): - if rank in assigned_ranks: - matched_sample += 1 - if matched_sample == microbatch_id: - current_cp_assignment[i] = assigned_ranks - break - return current_cp_assignment - - -def heterogeneous_context_parallel( - single_forward_step, backward_step, total_num_tokens, input_tensor, output_tensor_grad, model_type -): - """ - Heterogeneous context parallel is a technique to balance the workload - of each CP rank when we use packed samples with variable sequence lengths. - This provides a wrapper function that replaces the 1 forward step of the - original microbatch with N forward + N-1 backward steps where N is the number - of sub-samples assigned to this CP rank. + Args: + batch: The batch of data to slice. + scheduled_id: The index of the sub-sample to return. + local_cp_size: The size of the CP group. + packed_seq_params: The updated packed sequence parameters. """ - - def forward_func_wrapper(*args, **kwargs): - nonlocal total_num_tokens - rank = parallel_state.get_context_parallel_rank() - forward_signature = inspect.signature(single_forward_step) - bound_args = forward_signature.bind(*args, **kwargs) - bound_args.apply_defaults() - original_data_iterator = bound_args.arguments['data_iterator'] - data = next(original_data_iterator) # TODO: Protect for model parallelism - config = bound_args.arguments['config'] - # calculate new loop count - assert hasattr(data, "cu_seqlens"), ( - "data must have a cu_seqlens attribute to define the valid sequenece lengths " - "of each sub-sample in a packed sample to use heterogeneous context parallel" - ) - # num_subsamples: number of sub-samples assigned to this CP rank - # complete_cp_assignment: list of lists, inner list CP ranks assigned to a sub-sample - # TODO: When some ranks finish with their assigned sub-samples, they get the next microbatch from the data iterator. - # They then call heterogeneous_cp_assignment again. But some ranks are still executing their previous sub-samples. - # The scheduling algorithm assumes that all ranks have empty queues. - # Need to preserve the state of the queue. It should only empty at the end of global batch. - # OR should I force sync at the end of each microbatch across CP ranks? - # I think forcing sync at the end of each microbatch across CP ranks is the easier option. - # Since we only see 1 microbatch at a time, a later microbatch can have more GPUs for a sample than the 1st microbatch groups. - # For example, if microbatch 0 had group of 4 GPUs running for a sub-sample but microbatch 1 requires 6 GPUs - # If GPU 4,5 are done before 0,1,2,3 in microbatch 0, then they will be waiting for 0,1,2,3 to catch up which can lead to deadlock in comms. - # Unless we can stop 4,5 from executing which means partial syncs for each executing sub-sample. - num_subsamples, complete_cp_assignment = get_heterogeneous_cp_assignment( - data["cu_seqlens"], - config.max_seqlen_per_cp_rank, - config.context_parallel_size, - ) - # current_cp_assignment: list of lists, each inner list contains the - # CP ranks assigned to the sub-samples that are executing in the forward-backward loop. - # See function get_current_cp_assignment for more details. - current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, 0, rank) - data["cp_assignment"] = current_cp_assignment - # TODO: Make this a tensor so that tensor parallel broadcast works - # data["cp_assignment"] = torch.tensor(current_cp_assignment, dtype=torch.int32) - bound_args.arguments['data_iterator'] = RerunDataIterator(iter([data])) - # Run the 1st micro-microbatch - output_tensor, num_tokens = single_forward_step(*bound_args.args, **bound_args.kwargs) - total_num_tokens += num_tokens - # Run the N-1 backward steps, N-1 forward steps. - # Will be left with Nth backward step after this loop which is run in the original function. - for i in range(1, num_subsamples): - backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - - current_cp_assignment = get_current_cp_assignment(complete_cp_assignment, i, rank) - data["cp_assignment"] = current_cp_assignment - bound_args.arguments['data_iterator'] = RerunDataIterator(iter([data])) - output_tensor, num_tokens = single_forward_step(*bound_args.args, **bound_args.kwargs) - total_num_tokens += num_tokens - return output_tensor, total_num_tokens - return forward_func_wrapper - -def get_sub_sample_on_this_cp_rank(batch, current_cp_assignment, packed_seq_params, cp_rank): cu_lengths = packed_seq_params.cu_seqlens_q_padded - for i in range(len(current_cp_assignment)): - if current_cp_assignment[i] is not None: - assert cp_rank in current_cp_assignment[i], f"Current cp rank {cp_rank} is not part of the cp_assignment {current_cp_assignment[i]} given to this GPU" - start_index = cu_lengths[i] - end_index = cu_lengths[i+1] - cp_shard_ranks = current_cp_assignment[i] - break + start_index = cu_lengths[scheduled_id] + end_index = cu_lengths[scheduled_id+1] + # TODO (flexible HCP): New CP size also means new padding requirement. CP4 to CP3 changes padding requirement. for key, data in batch.items(): batch[key] = data[:, start_index:end_index] - - # TODO: Clean this up. Reduce code by calculating indices once + + # TODO (milestone 2): Enable this when we do DPxCP + # cp_group = parallel_state.get_hybrid_data_context_parallel_groups(group_size=local_cp_size) + if local_cp_size > 1: + cp_group = parallel_state.get_hybrid_context_parallel_groups(group_size=local_cp_size) + else: + cp_group = None + sub_sample_packed_seq_params = PackedSeqParams( qkv_format="sbhd", - cu_seqlens_q=torch.tensor([0, packed_seq_params.cu_seqlens_q[i+1] - packed_seq_params.cu_seqlens_q[i]], device="cpu", pin_memory=True), - cu_seqlens_kv=torch.tensor([0, packed_seq_params.cu_seqlens_kv[i+1] - packed_seq_params.cu_seqlens_kv[i]], device="cpu", pin_memory=True), - cu_seqlens_q_padded=torch.tensor([0, packed_seq_params.cu_seqlens_q_padded[i+1] - packed_seq_params.cu_seqlens_q_padded[i]], device="cpu", pin_memory=True), - cu_seqlens_kv_padded=torch.tensor([0, packed_seq_params.cu_seqlens_kv_padded[i+1] - packed_seq_params.cu_seqlens_kv_padded[i]], device="cpu", pin_memory=True), - max_seqlen_q=torch.tensor([end_index - start_index], device="cpu", pin_memory=True), - max_seqlen_kv=torch.tensor([end_index - start_index], device="cpu", pin_memory=True), - cp_assignment=cp_shard_ranks, + cu_seqlens_q=torch.tensor([0, packed_seq_params.cu_seqlens_q[scheduled_id+1] - packed_seq_params.cu_seqlens_q[scheduled_id]], device="cpu", pin_memory=True), + cu_seqlens_kv=torch.tensor([0, packed_seq_params.cu_seqlens_kv[scheduled_id+1] - packed_seq_params.cu_seqlens_kv[scheduled_id]], device="cpu", pin_memory=True), + cu_seqlens_q_padded=torch.tensor([0, packed_seq_params.cu_seqlens_q_padded[scheduled_id+1] - packed_seq_params.cu_seqlens_q_padded[scheduled_id]], device="cpu", pin_memory=True), + cu_seqlens_kv_padded=torch.tensor([0, packed_seq_params.cu_seqlens_kv_padded[scheduled_id+1] - packed_seq_params.cu_seqlens_kv_padded[scheduled_id]], device="cpu", pin_memory=True), + max_seqlen_q=end_index - start_index, + max_seqlen_kv=end_index - start_index, + local_cp_size=local_cp_size, ) - return batch, cp_shard_ranks, sub_sample_packed_seq_params + # TODO: Should we return the sharded sample directly here? + + return batch, cp_group, sub_sample_packed_seq_params ###################### diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 369871b9fe2..7b0b0902006 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2414,10 +2414,10 @@ def _add_distributed_args(parser): group.add_argument('--max-seqlen-per-cp-rank', type=int, default=None, help='Maximum sequence length per CP rank. This is used to calculate the ' 'number of sub-samples assigned to each CP rank when using heterogeneous context parallel.') - group.add_argument('--heterogeneous-context-parallel', action='store_true', default=False, - help='Enables heterogeneous context parallel. This is used to balance the workload ' - 'of each CP rank when we use packed samples with variable sequence lengths. Requires ' - '--max-seqlen-per-cp-rank to be set.') + group.add_argument('--hybrid-context-parallel', action='store_true', default=False, + help='Enables hybrid context parallel. This is used to balance the workload ' + 'of each CP rank when we use packed samples with variable sequence lengths. ' + 'Requires --max-seqlen-per-cp-rank to be set.') group.add_argument('--nccl-communicator-config-path', type=str, default=None, help='Path to the yaml file with NCCL communicator ' 'configurations. The number of min/max thread groups and thread ' diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py index c689a8a9c5f..5cc4691db6d 100644 --- a/megatron/training/initialize.py +++ b/megatron/training/initialize.py @@ -350,6 +350,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s use_sharp=args.use_sharp, context_parallel_size=args.context_parallel_size, hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes, + hybrid_context_parallel=args.hybrid_context_parallel, expert_model_parallel_size=args.expert_model_parallel_size, num_distributed_optimizer_instances=args.num_distributed_optimizer_instances, expert_tensor_parallel_size=args.expert_tensor_parallel_size, From 771fc21263ba0863643e3605ce1df6d4492e9121 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 21 Aug 2025 16:44:26 -0700 Subject: [PATCH 12/81] Add SFT dataset support --- megatron/core/datasets/gpt_dataset.py | 3 + megatron/training/datasets/sft_dataset.py | 205 ++++++++++++++-------- megatron/training/utils.py | 68 +++++++ pretrain_gpt.py | 71 +++++++- 4 files changed, 267 insertions(+), 80 deletions(-) diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index b80caaf6ca0..37d984c2a36 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -52,6 +52,9 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig): object_storage_cache_path: Optional[str] = None """Path for caching indices for s3 or msc dataloading.""" + context_parallel_size: int = 1 + """Option to enable context parallelism""" + def __post_init__(self) -> None: """Do asserts and set fields post init""" super().__post_init__() diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index 44547d5f4be..5f37a4c5b7e 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -68,82 +68,135 @@ def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> LowL def __len__(self) -> int: return self.num_samples + def _split_conversations(self, merged_conversations): + split_conversations = [] + current = [] + for msg in merged_conversations: + # Whenever we see a new system message, start a new conversation + if msg["role"] == "system": + if current: # If previously accumulating a conversation, then store it + split_conversations.append(current) + current = [msg] # Then start the new conversation + else: + current.append(msg) # Continue accumulating the current conversation + if current: # Store any remaining conversation + split_conversations.append(current) + return split_conversations + def __getitem__(self, idx: int) -> Dict[str, Any]: tokenizer = self.config.tokenizer - max_seq_len = self.config.sequence_length - - conversation_list = self.dataset[int(self.indices[idx % len(self.indices)])] - tokens, target = tokenizer.tokenize_conversation( - conversation_list, return_target=True, add_generation_prompt=False - ) - - # minus one to insert eos token - if len(tokens) > max_seq_len - 1: - if True: # TODO: when too long to fit in context, truncate left to right - tokens = tokens[: max_seq_len - 1] - target = target[: max_seq_len - 1] - else: # right to left - tokens = tokens[-(max_seq_len - 1) :] - target = target[-(max_seq_len - 1) :] - - # padding - num_tokens = len(tokens) + 1 - padding_len = max_seq_len - num_tokens - assert padding_len >= 0 - filler = [tokenizer.pad] * (padding_len + 1) - - tokens = np.array(tokens.tolist() + [tokenizer.eod] + filler, dtype=np.int64) - target = np.array(target.tolist() + [tokenizer.eod] + filler, dtype=np.int64) - - tokens = torch.tensor(tokens) - target = torch.tensor(target) - - tokens = tokens[:-1].contiguous() - target = target[1:].contiguous() - - loss_mask, position_ids, attention_mask = self._get_ltor_masks_and_position_ids( - max_seq_len, target, tokenizer.pad - ) - - if self.config.create_attention_mask: - ret = { - 'tokens': tokens, - 'labels': target, - 'attention_mask': attention_mask, - 'loss_mask': loss_mask, - 'position_ids': position_ids, - } - else: - ret = { - 'tokens': tokens, - 'labels': target, - 'loss_mask': loss_mask, - 'position_ids': position_ids, - } - - return ret - - def _get_ltor_masks_and_position_ids(self, max_seq_len, target, pad_token): - """Build masks and position id for left to right model for SFT""" - - assert not self.config.reset_position_ids and not self.config.reset_attention_mask - - # Position ids. - position_ids = torch.arange(max_seq_len, dtype=torch.long) - - # Loss mask. - loss_mask = torch.ones(max_seq_len, dtype=torch.float) - loss_mask[target == pad_token] = 0.0 # mask paddings - loss_mask[target == IGNORE_INDEX] = 0.0 # mask prompts - - if self.config.create_attention_mask: - attention_mask = torch.tril( - torch.ones((seq_length, seq_length), device=data.device) - ).unsqueeze(0) - # Convert attention mask to binary: - attention_mask = attention_mask < 0.5 - else: - attention_mask = None - - return loss_mask, position_ids, attention_mask + pack_length = self.config.sequence_length + + merged_conversations = self.dataset[int(self.indices[idx % len(self.indices)])] + split_conversations = self._split_conversations(merged_conversations) + + def extend_with_padding(tokens, targets, positions, pad_len): + tokens.extend([pad] * pad_len) + targets.extend([pad] * pad_len) + positions.extend(range(positions[-1]+1, positions[-1]+1+pad_len)) + + pack_tokens = [] + pack_targets = [] + pack_positions = [] + cu_seqlens = [0] + eod = tokenizer.eod + pad = tokenizer.pad + # TODO(duncan): Track number of convs dropped and/or truncated and amount of end-padding + for conversation in split_conversations: + + tokens, targets = tokenizer.tokenize_conversation( + conversation, return_target=True, add_generation_prompt=False + ) + + tokens_list = tokens.tolist() + targets_list = targets.tolist() + + # Add EOD, unless it's already present + if tokens_list[-1] != eod: + tokens_list.append(eod) + targets_list.append(eod) + + pack_tokens.extend(tokens_list) + pack_targets.extend(targets_list) + + assert not self.config.reset_position_ids + pack_positions.extend(range(len(tokens_list))) + + if self.config.context_parallel_size > 1: + pad_granularity = self.config.context_parallel_size * 2 + mod_token_count = len(pack_tokens) % pad_granularity + if mod_token_count != 0: + pad_len = pad_granularity - mod_token_count + extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) + + # TODO(duncan): Consider also padding to multiple of number of tokens here. This might + # be needed for efficiency (and potentially set via command-line argument). + + cu_seqlens.append(len(pack_tokens)) + + # Handle any necessary truncation + # + # Consider the case where the previous iteration led to + # len(pack_tokens) == pack_length. Then pack_tokens[pack_length-1] == eod. On this + # current iteration len(pack_tokens) >= pack_length + 1. Truncation here will then + # strip off the eod from the previous iteration and re-apply it. + # + # Consider the case where the previous iteration led to + # len(pack_tokens) == pack_length - 1. Then pack_tokens[pack_length-2] == eod. On + # this current iteration len(pack_tokens) >= pack_length + 1. Truncation here will + # then apply a second eod at location pack_tokens[pack_length-1]. So it is possible + # to have two eod tokens in a row. I'm not sure if this is a problem. + # + if len(pack_tokens) >= pack_length + 1: # +1 here to account for later alignment + max_body = pack_length - 1 + # Truncate on the left. TODO(duncan): Consider optionally trunc. on the right + pack_tokens = pack_tokens[-max_body:] + pack_targets = pack_targets[-max_body:] + pack_tokens.extend([eod, pad]) + pack_targets.extend([eod, pad]) + pack_positions = pack_positions[:pack_length+1] + # Note len({pack_tokens, pack_targets, pack_positions}) should be pack_length + 1 + cu_seqlens[-1] = len(pack_tokens) - 1 + break + + # Handle any necessary padding + if len(pack_tokens) < pack_length + 1: # +1 here to account for later alignment + pad_len = pack_length + 1 - len(pack_tokens) + extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) + # Note len({pack_tokens, pack_targets, pack_positions}) should be pack_length + 1 + cu_seqlens[-1] = len(pack_tokens) - 1 + + assert len(pack_tokens) == pack_length + 1 + assert len(pack_targets) == pack_length + 1 + assert len(pack_positions) == pack_length + 1 + + # Align and convert to tensors + input_ids = torch.tensor(pack_tokens[:-1], dtype=torch.int64) + labels = torch.tensor(pack_targets[1:], dtype=torch.int64) + position_ids = torch.tensor(pack_positions[:-1], dtype=torch.int64) + + loss_mask = torch.ones(pack_length, dtype=torch.float32) + loss_mask[labels == pad] = 0.0 # Mask paddings + loss_mask[labels == IGNORE_INDEX] = 0.0 # mask prompts + + # TODO(duncan): Optionally create an attention mask + assert not self.config.create_attention_mask and not self.config.reset_attention_mask + # attention_mask = None + + assert len(cu_seqlens) >= 2 + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32) + # Calculating max_seqlen here, rather than incrementally above, because of possible + # effects of truncation and padding + adjacent_diffs = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = adjacent_diffs.max() # max_seqlen is a 0-D tensor + + return { + 'tokens': input_ids, + 'labels': labels, + # 'attention_mask': attention_mask, # PyTorch collate cannot handle NoneType + 'loss_mask': loss_mask, + 'position_ids': position_ids, + 'cu_seqlens': cu_seqlens, + 'max_seqlen': max_seqlen, + } diff --git a/megatron/training/utils.py b/megatron/training/utils.py index a873a89514b..0b89ad26664 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -461,19 +461,50 @@ def _broadcast(item): else data["attention_mask"].cuda(non_blocking=True) ), 'position_ids': data["position_ids"].cuda(non_blocking=True), + 'cu_seqlens': ( + None + if "cu_seqlens" not in data + else data["cu_seqlens"].cuda(non_blocking=True) + ), + 'max_seqlen': ( + None + if "max_seqlen" not in data + else data["max_seqlen"].cuda(non_blocking=True) + ), } + def _broadcast_cu_seqlens(cu_seqlens): + dev = torch.cuda.current_device() + + n = 0 if cu_seqlens is None else int(cu_seqlens.numel()) + n_tensor = torch.tensor(n, dtype=torch.int64, device=dev) + _broadcast(n_tensor) + + if n == 0: + buf = torch.empty(0, dtype=torch.int32, device=dev) + else: + assert isinstance(cu_seqlens, torch.Tensor) + assert cu_seqlens.dtype == torch.int32 + assert cu_seqlens.shape[0] == 1, "micro-batch-size must be 1 for packing" + buf = cu_seqlens.to(device=dev, non_blocking=True).contiguous() + _broadcast(buf) + + if args.pipeline_model_parallel_size == 1: _broadcast(batch['tokens']) _broadcast(batch['labels']) _broadcast(batch['loss_mask']) _broadcast(batch['attention_mask']) _broadcast(batch['position_ids']) + _broadcast_cu_seqlens(batch['cu_seqlens']) + _broadcast(batch['max_seqlen']) elif mpu.is_pipeline_first_stage(): _broadcast(batch['tokens']) _broadcast(batch['attention_mask']) _broadcast(batch['position_ids']) + _broadcast_cu_seqlens(batch['cu_seqlens']) + _broadcast(batch['max_seqlen']) elif mpu.is_pipeline_last_stage(): # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding. @@ -482,6 +513,8 @@ def _broadcast(item): if args.mtp_num_layers is not None: _broadcast(batch['tokens']) _broadcast(batch['position_ids']) + _broadcast_cu_seqlens(batch['cu_seqlens']) + _broadcast(batch['max_seqlen']) _broadcast(batch['labels']) _broadcast(batch['loss_mask']) _broadcast(batch['attention_mask']) @@ -517,12 +550,39 @@ def _broadcast(item): device=torch.cuda.current_device(), ) + cu_seqlens = None + if args.sft: + max_seqlen = torch.empty( + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + else: + max_seqlen = None + + def _broadcast_cu_seqlens(): + dev = torch.cuda.current_device() + + n = torch.empty((), dtype=torch.int64, device=dev) + _broadcast(n) + n = int(n.item()) + + if n == 0: + cu_seqlens = torch.empty(0, dtype=torch.int32, device=dev) + else: + cu_seqlens = torch.empty((args.micro_batch_size, n), dtype=torch.int32, device=dev) + _broadcast(cu_seqlens) + + return cu_seqlens if n > 0 else None + if args.pipeline_model_parallel_size == 1: _broadcast(tokens) _broadcast(labels) _broadcast(loss_mask) _broadcast(attention_mask) _broadcast(position_ids) + cu_seqlens = _broadcast_cu_seqlens() + _broadcast(max_seqlen) elif mpu.is_pipeline_first_stage(): labels = None @@ -531,6 +591,8 @@ def _broadcast(item): _broadcast(tokens) _broadcast(attention_mask) _broadcast(position_ids) + cu_seqlens = _broadcast_cu_seqlens() + _broadcast(max_seqlen) elif mpu.is_pipeline_last_stage(): # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding. @@ -539,9 +601,13 @@ def _broadcast(item): if args.mtp_num_layers is not None: _broadcast(tokens) _broadcast(position_ids) + cu_seqlens = _broadcast_cu_seqlens() + _broadcast(max_seqlen) else: tokens = None position_ids = None + cu_seqlens = None + max_seqlen = None _broadcast(labels) _broadcast(loss_mask) @@ -553,6 +619,8 @@ def _broadcast(item): 'loss_mask': loss_mask, 'attention_mask': attention_mask, 'position_ids': position_ids, + 'cu_seqlens': cu_seqlens, + 'max_seqlen': max_seqlen, } return batch diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 77346348e34..7d9b941278e 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -16,6 +16,11 @@ from megatron.training import get_tokenizer from megatron.core import mpu from megatron.core.enums import ModelType +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.parallel_state import ( + get_context_parallel_rank, + get_context_parallel_world_size, +) from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset from megatron.core.enums import ModelType @@ -31,7 +36,7 @@ ) from megatron.core.rerun_state_machine import get_rerun_state_machine from megatron.core.transformer.spec_utils import import_module -from megatron.core.utils import StragglerDetector +from megatron.core.utils import is_te_min_version, StragglerDetector from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0 from megatron.training.arguments import core_transformer_config_from_args from megatron.training.utils import ( @@ -55,6 +60,16 @@ except ImportError: has_nvidia_modelopt = False +try: + # Register the TE CUDA kernels + import transformer_engine # pylint: disable=unused-import + + # Alias the PyTorch wrapper so we can call tex.* APIs + import transformer_engine_torch as tex +except ImportError: + # TE isn’t installed or the torch wrapper is missing + tex = None + stimer = StragglerDetector() @@ -211,7 +226,41 @@ def get_batch(data_iterator): batch = get_batch_on_this_tp_rank(data_iterator) # slice batch along sequence dimension for context parallelism - batch = get_batch_on_this_cp_rank(batch) + # batch = get_batch_on_this_cp_rank(batch) + + cu_seqlens = batch['cu_seqlens'] + if cu_seqlens is None: + # slice batch along sequence dimension for context parallelism + batch = get_batch_on_this_cp_rank(batch) # The implementation of this function is in MCore + else: # Packed THD format + assert ( + cu_seqlens.dim() == 2 and cu_seqlens.shape[0] == 1 + ), "micro-batch-size must be 1 for packing" + cu_seqlens = cu_seqlens[0] + batch['cu_seqlens'] = cu_seqlens + + max_seqlen = batch['max_seqlen'] + assert max_seqlen.dim() == 1 + # TODO(duncan): can this be kept as a 0-D tensor? + batch['max_seqlen'] = int(max_seqlen[0].item()) + + cp_size = get_context_parallel_world_size() + if cp_size > 1: # slice batch along sequence dimension for context parallelism + assert tex is not None and is_te_min_version("1.10.0"), ( + "Please update Transformer Engine to >= 1.10 to use " + "Context Parallel with THD format data" + ) + cp_rank = get_context_parallel_rank() + index = tex.thd_get_partitioned_indices( + cu_seqlens, + batch['tokens'].size(1), + cp_size, + cp_rank, + ) + for key, data in batch.items(): + if key in {'attention_mask', 'cu_seqlens', 'max_seqlen'}: + continue + batch[key] = data.index_select(1, index) return batch.values() @@ -297,7 +346,20 @@ def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = Fa timers('batch-generator', log_level=2).start() global stimer with stimer(bdata=True): - tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator) + tokens, labels, loss_mask, attention_mask, position_ids, cu_seqlens, max_seqlen = get_batch(data_iterator) + if cu_seqlens is None: + packed_seq_params = None + else: + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_q_padded=None, + cu_seqlens_kv_padded=None, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + ) + timers('batch-generator').stop() with stimer: @@ -313,7 +375,7 @@ def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = Fa return schedule_plan, partial(loss_func, loss_mask, model=model) else: output_tensor = model( - tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask + tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask, packed_seq_params=packed_seq_params ) # [ModelOpt]: model is needed to access ModelOpt distillation losses @@ -351,6 +413,7 @@ def core_gpt_dataset_config_from_args(args): create_attention_mask=args.create_attention_mask_in_dataloader, object_storage_cache_path=args.object_storage_cache_path, mid_level_dataset_surplus=args.mid_level_dataset_surplus, + context_parallel_size=args.context_parallel_size, ) From 3a7e50a1010da36fd08cec8cedd1c57a85ddc5ac Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 21 Aug 2025 17:02:07 -0700 Subject: [PATCH 13/81] New updates to use DataIterator wrapper --- .../pipeline_parallel/hybrid_cp_schedule.py | 51 ++++++++++++++++++- megatron/core/utils.py | 8 +-- megatron/training/training.py | 4 ++ 3 files changed, 57 insertions(+), 6 deletions(-) diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index 8ee6ec02b32..05b145af49d 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -7,6 +7,49 @@ import math import heapq +class HybridCPWrapper(): + """ + A wrapper class that wraps around any existing dataset and prints samples + as they are requested from the original dataset. + + This wrapper implements the standard PyTorch Dataset interface and can be + used with any dataset that follows the same interface. + + Args: + dataset: The original dataset to wrap around + print_format: Format string for printing samples (default: "Sample {idx}: {sample}") + print_func: Function to use for printing (default: print) + max_print_length: Maximum length of sample to print (default: 200 chars) + print_every: Print every Nth sample (default: 1, print all) + """ + + def __init__( + self, + data_iterator, + config, + ): + self.data_iterator = data_iterator + self.sample_count = 0 + self.config = config + self.cp_balancing_scheduler = BalancedCPScheduler(max_seq_len_per_rank=self.config.max_seqlen_per_cp_rank) + + def __iter__(self): + """Return self as an iterator.""" + return self + + def __next__(self) -> Any: + """ + Get the next item from the dataset, pull scheduling metadata and return it. + """ + sample = next(self.data_iterator) + assert "cu_lengths" in sample, "cu_lengths must be in the sample" + # TODO(milestone 2): Get cu_lengths and all-gather the entire global batch worth cu_lengths and then perform the scheduling. + # But why should this scheduling information be integrated back into the data? + groups, sample_id_groups = self.cp_balancing_scheduler.get_groups_and_subsamples(sample, self.config) + sample["groups"] = groups + sample["sample_id_groups"] = sample_id_groups + return sample + class BalancedCPScheduler: def __init__(self, max_seq_len_per_rank: int): self.max_seq_len_per_rank = max_seq_len_per_rank @@ -554,7 +597,9 @@ def hybrid_context_parallel_forward_backward( with no_sync_func(): for i in range(num_microbatches - 1): data = next(data_iterator) - groups, sample_id_groups = cp_balancing_scheduler.get_groups_and_subsamples(data, model, config) + # groups, sample_id_groups = cp_balancing_scheduler.get_groups_and_subsamples(data, model, config) + groups = data["groups"] + sample_id_groups = data["sample_id_groups"] for j in range(len(groups)): # Get sub-samples for the current CP rank # TODO: Update to DPxCP rank when milestone 2 @@ -593,7 +638,9 @@ def hybrid_context_parallel_forward_backward( # TODO: Call scheduler here. with no_sync_func(): data = next(data_iterator) - groups, sample_id_groups = cp_balancing_scheduler.get_groups_and_subsamples(data, model, config) + # groups, sample_id_groups = cp_balancing_scheduler.get_groups_and_subsamples(data, model, config) + groups = data["groups"] + sample_id_groups = data["sample_id_groups"] for j in range(len(groups) - 1): sample_ids_per_group = sample_id_groups[j][parallel_state.get_context_parallel_rank()] for k in range(len(sample_ids_per_group)): diff --git a/megatron/core/utils.py b/megatron/core/utils.py index 89dfdfdad16..d2e0cd85560 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -1897,10 +1897,10 @@ def get_sub_sample_on_this_cp_rank(batch, scheduled_id, local_cp_size, packed_se sub_sample_packed_seq_params = PackedSeqParams( qkv_format="sbhd", - cu_seqlens_q=torch.tensor([0, packed_seq_params.cu_seqlens_q[scheduled_id+1] - packed_seq_params.cu_seqlens_q[scheduled_id]], device="cpu", pin_memory=True), - cu_seqlens_kv=torch.tensor([0, packed_seq_params.cu_seqlens_kv[scheduled_id+1] - packed_seq_params.cu_seqlens_kv[scheduled_id]], device="cpu", pin_memory=True), - cu_seqlens_q_padded=torch.tensor([0, packed_seq_params.cu_seqlens_q_padded[scheduled_id+1] - packed_seq_params.cu_seqlens_q_padded[scheduled_id]], device="cpu", pin_memory=True), - cu_seqlens_kv_padded=torch.tensor([0, packed_seq_params.cu_seqlens_kv_padded[scheduled_id+1] - packed_seq_params.cu_seqlens_kv_padded[scheduled_id]], device="cpu", pin_memory=True), + cu_seqlens_q=torch.tensor([0, packed_seq_params.cu_seqlens_q[scheduled_id+1] - packed_seq_params.cu_seqlens_q[scheduled_id]], device="cuda", pin_memory=True), + cu_seqlens_kv=torch.tensor([0, packed_seq_params.cu_seqlens_kv[scheduled_id+1] - packed_seq_params.cu_seqlens_kv[scheduled_id]], device="cuda", pin_memory=True), + cu_seqlens_q_padded=torch.tensor([0, packed_seq_params.cu_seqlens_q_padded[scheduled_id+1] - packed_seq_params.cu_seqlens_q_padded[scheduled_id]], device="cuda", pin_memory=True), + cu_seqlens_kv_padded=torch.tensor([0, packed_seq_params.cu_seqlens_kv_padded[scheduled_id+1] - packed_seq_params.cu_seqlens_kv_padded[scheduled_id]], device="cuda", pin_memory=True), max_seqlen_q=end_index - start_index, max_seqlen_kv=end_index - start_index, local_cp_size=local_cp_size, diff --git a/megatron/training/training.py b/megatron/training/training.py index 6f6ac91a652..de3d44d14f9 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -96,6 +96,7 @@ get_pp_rank_microbatches, get_schedule_table, ) +from megatron.core.pipeline_parallel.hybrid_cp_schedule import HybridCPWrapper from megatron.core.num_microbatches_calculator import ( destroy_num_microbatches_calculator, get_current_global_batch_size, @@ -2085,6 +2086,9 @@ def train( energy_monitor = get_energy_monitor() one_logger = get_one_logger() + if args.hybrid_context_parallel: + train_data_iterator = HybridCPWrapper(train_data_iterator, config) + if args.run_workload_inspector_server: try: from workload_inspector.utils.webserver import run_server From bd2321d468d9faf124da61e3106bcc16cecdfd7a Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 21 Aug 2025 17:20:15 -0700 Subject: [PATCH 14/81] Test scripts --- examples/run_hybrid_cp.sh | 199 ++++++++++++++++++++++++++++++++++++ interactive_pytorch25.02.sh | 29 ++++++ 2 files changed, 228 insertions(+) create mode 100755 examples/run_hybrid_cp.sh create mode 100755 interactive_pytorch25.02.sh diff --git a/examples/run_hybrid_cp.sh b/examples/run_hybrid_cp.sh new file mode 100755 index 00000000000..a085b569456 --- /dev/null +++ b/examples/run_hybrid_cp.sh @@ -0,0 +1,199 @@ +#!/bin/bash + +#SBATCH -A coreai_dlalgo_llm +# DFW: batch +# OCI-NRT: batch_block1 +# OCI-IAD: batch_block1,batch_block3,batch_block4,backfill_block1,backfill_block2,backfill_block3,backfill_block4 +#SBATCH -p batch +#SBATCH -t 00:30:00 +#SBATCH --mem=0 +#SBATCH --ntasks-per-node=8 +#SBATCH --nodes=4 +#SBATCH --exclusive +#SBATCH --gpus-per-node=8 +#SBATCH --job-name=video_sft_stage3_qwen_2p5_7b_radio_research_cp_0518_exp + +export NCCL_IB_SL=1 +export TOKENIZERS_PARALLELISM="false" +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +#export NVTE_DEBUG=1 +#export NVTE_DEBUG_LEVEL=2 + +USER=$SLURM_JOB_USER + +# Auto-detect batch or interactive mode. +which srun +BATCH=$((1-$?)) + +DEBUG=0 +USE_TILING=1 +USE_CP=1 +USE_TE_CE=0 +USE_FLASH_ATTN=0 +USE_FSDP=0 +USE_CUSTOM_FSDP=0 +PROFILE=0 + +# Remember to update model and job name if running in batch mode!! +if [[ $BATCH -eq 0 ]]; then + DATETIME=`date +'%y-%m-%d-%H-%M-%S'` + MODEL_NAME="interactive_hybrid_cp_${DATETIME}" + DEBUG=1 +else + MODEL_NAME="hybrid_cp_0821_exp" +fi + +WORKSPACE="/lustre/fsw/portfolios/coreai/users/pmannan/workspace" +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR=${OUTPUT}/checkpoints +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +DATA_TRAIN="/lustre/fs1/portfolios/llmservice/users/adithyare/sft/nano_v2_fake_packed_131072_10000_rndm//stage1_stage2_multiling_128k_seq_packed.empty_assist_filtered.shuf.jsonl" + +SEQ_LEN=131072 #131072 #81920 #65536 + +if [[ $DEBUG -eq 1 ]]; then + MBZ=1 + BZ=2 + NW=4 + AD=0.0 + HD=0.0 + LI=1 + + # EXTRA_ARGS="--deterministic-mode --use-cpu-initialization" + + NONDETERMINISTIC_ATTN=1 + + NUM_GPU=8 + export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + + #export NCCL_ALGO=Tree + #export CUBLAS_WORKSPACE_CONFIG=:4096:8 +else + MBZ=1 + BZ=64 + NW=8 + AD=0.0 + HD=0.0 + LI=1 + EXTRA_ARGS="" + NONDETERMINISTIC_ATTN=1 + NUM_GPU=8 +fi + +if [[ $USE_CP -eq 1 ]]; then + if [[ $BATCH -eq 1 ]]; then + CP_SIZE=8 + else + CP_SIZE=8 + fi + EXTRA_ARGS+=" --context-parallel-size ${CP_SIZE} " +fi + +if [[ $USE_TE_CE -eq 1 ]]; then + EXTRA_ARGS+=" --enable-te-ce --cross-entropy-loss-fusion " +fi + +if [[ $PROFILE -eq 1 ]]; then + EXTRA_ARGS+="--profile --profile-step-start 1 --profile-step-end 2 --profile-ranks 0 1 2 3 4 5 6 7" +fi + +# CHECKPOINT_DIR="/lustre/fsw/portfolios/llmservice/users/trintamaki/workspace/output/video_sft_stage2_qwen_2p5_7b_radio_research_cp_0429_tp2/checkpoints" +TP=1 +EXTRA_ARGS+=" --ckpt-format torch --use-distributed-optimizer " +# EXTRA_ARGS+=" --overlap-param-gather --overlap-grad-reduce " +EXTRA_ARGS+=" --recompute-vision-projection " +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +EXTRA_ARGS+=" --use-precision-aware-optimizer --main-grads-dtype bf16 --main-params-dtype fp16 --exp-avg-dtype fp16 --exp-avg-sq-dtype fp16 " + +OPTIONS=" \ + --use-checkpoint-args \ + --disable-bias-linear \ + --sft \ + --tokenizer-type SFTTokenizer \ + --sft-tokenizer-prompt-format nemotron-h-aligned + --tokenizer-model /lustre/fsw/portfolios/llmservice/users/kezhik/images/Nemotron-H-4B-Instruct \ + --transformer-impl transformer_engine \ + --normalization RMSNorm \ + --norm-epsilon 1e-06 \ + --attention-dropout ${AD} \ + --hidden-dropout ${HD} \ + --untie-embeddings-and-output-weights \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size 1 \ + --num-layers 2 \ + --hidden-size 2048 \ + --ffn-hidden-size 8192 \ + --add-qkv-bias \ + --num-attention-heads 16 \ + --use-te \ + --num-workers ${NW} \ + --exit-duration-in-mins 230 \ + --seq-length ${SEQ_LEN} \ + --max-position-embeddings ${SEQ_LEN} \ + --train-samples 100000 \ + --lr-warmup-samples 20000 \ + --micro-batch-size ${MBZ} \ + --global-batch-size ${BZ} \ + --lr 2e-5 \ + --min-lr 0.0 \ + --lr-decay-style cosine \ + --log-interval ${LI} \ + --eval-iters 10 \ + --eval-interval 999999 \ + --data-path ${DATA_TRAIN} \ + --save-interval 1000 \ + --save ${FINETUNE_DIR} \ + --load ${FINETUNE_DIR} \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --use-mcore-models \ + --no-mmap-bin-files \ + --split 100,0,0 \ + --clip-grad 1.0 \ + --weight-decay 0.05 \ + --adam-beta1 0.9 \ + --adam-beta2 0.999 \ + --init-method-std 0.014 \ + --bf16 \ + --dataloader-type external \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + ${EXTRA_ARGS} \ + --distributed-timeout-minutes 60 \ + --packing-seq-length ${SEQ_LEN} \ + --calculate-per-token-loss \ +" + +# Interactive or batch mode +if [[ $BATCH -eq 0 ]]; then + if [[ $PROFILE -eq 1 ]]; then + nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi -o hetero_cp_iter1_%q{SLURM_PROCID} torchrun --nproc_per_node ${NUM_GPU} pretrain_gpt.py ${OPTIONS} + else + torchrun --nproc_per_node ${NUM_GPU} pretrain_gpt.py ${OPTIONS} + fi +else + if [[ $PROFILE -eq 1 ]]; then + run_cmd="cd ${SOURCE}; nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi --capture-range-end stop -o stage3_iter3_4_tp2cp4_w_tpcp_shard_%q{SLURM_PROCID} python -u pretrain_gpt.py ${OPTIONS}" + else + run_cmd="cd ${SOURCE}; python -u pretrain_gpt.py ${OPTIONS}" + fi + + DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` + + srun -l --verbose \ + --container-image /lustre/fsw/portfolios/coreai/users/pmannan/workspace/megatron_vlm_25.02_stage3-te_api.sqsh \ + --container-mounts "/lustre" \ + --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ + sh -c "${run_cmd}" + + set +x +fi \ No newline at end of file diff --git a/interactive_pytorch25.02.sh b/interactive_pytorch25.02.sh new file mode 100755 index 00000000000..3333cd934df --- /dev/null +++ b/interactive_pytorch25.02.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# H100 +CONTAINER_IMAGE="/lustre/fsw/portfolios/llmservice/users/trintamaki/workspace/containers/megatron-dev-img-03232025-pytorch25.02-te-main-160be21-energon-develop-bd46613.sqsh" +CONTAINER_IMAGE="/lustre/fsw/portfolios/llmservice/users/trintamaki/workspace/containers/megatron-dev-img-03232025-pytorch25.02-te-main-160be21-energon-main-6.0.1.sqsh" +CONTAINER_IMAGE="/lustre/fsw/portfolios/llmservice/users/trintamaki/workspace/containers/megatron-dev-img-03232025-pytorch25.02-te-main-160be21-energon-develop-max-pr.sqsh" +CONTAINER_IMAGE="/lustre/fsw/portfolios/llmservice/users/trintamaki/workspace/containers/megatron-dev-img-03232025-pytorch25.02-te-main-160be21-energon-develop-max-pr-onelogger.sqsh" +CONTAINER_IMAGE="/lustre/fsw/portfolios/llmservice/users/trintamaki/workspace/containers/megatron-dev-img-03232025-pytorch25.02-te-main-160be21-energon-develop-max-pr-onelogger-max-pr2.sqsh" +CONTAINER_IMAGE="/lustre/fsw/portfolios/coreai/users/pmannan/workspace/megatron_vlm_25.02-te_api.sqsh" + +# Set partitions based on hostname +if [[ $(hostname) == *"oci-iad"* ]]; then + PARTITIONS="interactive,batch_singlenode,backfill_singlenode,backfill_block1,backfill_block3,backfill_block4,batch_block1,batch_block2,batch_block3,batch_block4" +elif [[ $(hostname) == *"cw-dfw"* ]]; then + PARTITIONS="interactive,batch" +elif [[ $(hostname) == *"oci-nrt"* ]]; then + PARTITIONS="interactive,batch_block1,backfill,batch_singlenode" +else + PARTITIONS="interactive" +fi + +srun -p ${PARTITIONS} -A coreai_dlalgo_genai -N 1 --pty \ + --container-image /lustre/fsw/portfolios/llmservice/users/matthieul/docker/megatron-dev-img-05142025-pytorch-dev-te-cd37379-editable-energon-mamba-fix-vlmeval-pad-conv.sqsh \ + --container-mounts "/lustre" \ + --gpus 8 \ + --exclusive \ + --job-name "coreai_dlalgo_genai-megatron-dev:interactive" \ + -t 1:00:00 \ + bash -l \ No newline at end of file From 6a1ab4fdb49a2d7cd668ead7febb509d612a5fdb Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 21 Aug 2025 19:14:37 -0700 Subject: [PATCH 15/81] Minor fixes --- examples/run_hybrid_cp.sh | 13 +++++++------ .../core/pipeline_parallel/hybrid_cp_schedule.py | 3 ++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/run_hybrid_cp.sh b/examples/run_hybrid_cp.sh index a085b569456..775d42f0189 100755 --- a/examples/run_hybrid_cp.sh +++ b/examples/run_hybrid_cp.sh @@ -38,7 +38,7 @@ PROFILE=0 # Remember to update model and job name if running in batch mode!! if [[ $BATCH -eq 0 ]]; then DATETIME=`date +'%y-%m-%d-%H-%M-%S'` - MODEL_NAME="interactive_hybrid_cp_${DATETIME}" + MODEL_NAME="interactive_hybrid_cp" DEBUG=1 else MODEL_NAME="hybrid_cp_0821_exp" @@ -52,6 +52,9 @@ OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" FINETUNE_DIR=${OUTPUT}/checkpoints LOGS_DIR="${OUTPUT}/logs" TENSORBOARD_DIR="${OUTPUT}/tensorboard" +DATACACHE_DIR="${OUTPUT}/data_cache" + +export HF_DATASETS_CACHE="${OUTPUT}/hf_datasets_cache" DATA_TRAIN="/lustre/fs1/portfolios/llmservice/users/adithyare/sft/nano_v2_fake_packed_131072_10000_rndm//stage1_stage2_multiling_128k_seq_packed.empty_assist_filtered.shuf.jsonl" @@ -107,7 +110,6 @@ fi TP=1 EXTRA_ARGS+=" --ckpt-format torch --use-distributed-optimizer " # EXTRA_ARGS+=" --overlap-param-gather --overlap-grad-reduce " -EXTRA_ARGS+=" --recompute-vision-projection " export CUDA_DEVICE_MAX_CONNECTIONS=1 EXTRA_ARGS+=" --use-precision-aware-optimizer --main-grads-dtype bf16 --main-params-dtype fp16 --exp-avg-dtype fp16 --exp-avg-sq-dtype fp16 " @@ -131,12 +133,12 @@ OPTIONS=" \ --swiglu \ --tensor-model-parallel-size ${TP} \ --pipeline-model-parallel-size 1 \ + --rerun-mode disabled \ --num-layers 2 \ --hidden-size 2048 \ --ffn-hidden-size 8192 \ --add-qkv-bias \ --num-attention-heads 16 \ - --use-te \ --num-workers ${NW} \ --exit-duration-in-mins 230 \ --seq-length ${SEQ_LEN} \ @@ -155,8 +157,9 @@ OPTIONS=" \ --save-interval 1000 \ --save ${FINETUNE_DIR} \ --load ${FINETUNE_DIR} \ - --dataloader-save ${FINETUNE_DIR}/dataloader \ + --data-cache-path ${DATACACHE_DIR} \ --use-mcore-models \ + --no-create-attention-mask-in-dataloader \ --no-mmap-bin-files \ --split 100,0,0 \ --clip-grad 1.0 \ @@ -165,11 +168,9 @@ OPTIONS=" \ --adam-beta2 0.999 \ --init-method-std 0.014 \ --bf16 \ - --dataloader-type external \ --tensorboard-dir ${TENSORBOARD_DIR} \ ${EXTRA_ARGS} \ --distributed-timeout-minutes 60 \ - --packing-seq-length ${SEQ_LEN} \ --calculate-per-token-loss \ " diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index 05b145af49d..e3bd8c89b09 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -2,8 +2,9 @@ from megatron.core import parallel_state from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.rerun_state_machine import RerunDataIterator -from typing import Any, Callable, Dict, Deque, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from functools import lru_cache +from collections import deque import math import heapq From e01583a18a201d124c47d8d47e6da58c20b32d6d Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Fri, 22 Aug 2025 15:47:52 -0700 Subject: [PATCH 16/81] SFT Dataset + GPT with Hybrid CP working example --- examples/run_hybrid_cp.sh | 13 +- .../core/extensions/transformer_engine.py | 5 +- megatron/core/packed_seq_params.py | 3 +- megatron/core/parallel_state.py | 1 + .../pipeline_parallel/hybrid_cp_schedule.py | 126 ++---------------- megatron/core/utils.py | 3 + megatron/training/utils.py | 39 ++++-- pretrain_gpt.py | 66 +++++---- 8 files changed, 105 insertions(+), 151 deletions(-) diff --git a/examples/run_hybrid_cp.sh b/examples/run_hybrid_cp.sh index 775d42f0189..7123da2496c 100755 --- a/examples/run_hybrid_cp.sh +++ b/examples/run_hybrid_cp.sh @@ -33,7 +33,7 @@ USE_TE_CE=0 USE_FLASH_ATTN=0 USE_FSDP=0 USE_CUSTOM_FSDP=0 -PROFILE=0 +PROFILE=1 # Remember to update model and job name if running in batch mode!! if [[ $BATCH -eq 0 ]]; then @@ -62,7 +62,7 @@ SEQ_LEN=131072 #131072 #81920 #65536 if [[ $DEBUG -eq 1 ]]; then MBZ=1 - BZ=2 + BZ=8 NW=4 AD=0.0 HD=0.0 @@ -103,7 +103,7 @@ if [[ $USE_TE_CE -eq 1 ]]; then fi if [[ $PROFILE -eq 1 ]]; then - EXTRA_ARGS+="--profile --profile-step-start 1 --profile-step-end 2 --profile-ranks 0 1 2 3 4 5 6 7" + EXTRA_ARGS+="--profile --profile-step-start 7 --profile-step-end 8 " fi # CHECKPOINT_DIR="/lustre/fsw/portfolios/llmservice/users/trintamaki/workspace/output/video_sft_stage2_qwen_2p5_7b_radio_research_cp_0429_tp2/checkpoints" @@ -172,12 +172,17 @@ OPTIONS=" \ ${EXTRA_ARGS} \ --distributed-timeout-minutes 60 \ --calculate-per-token-loss \ + --hybrid-context-parallel \ + --max-seqlen-per-cp-rank 16384 \ + --attention-backend flash \ " +# --hybrid-context-parallel \ +# --max-seqlen-per-cp-rank 16384 \ # Interactive or batch mode if [[ $BATCH -eq 0 ]]; then if [[ $PROFILE -eq 1 ]]; then - nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi -o hetero_cp_iter1_%q{SLURM_PROCID} torchrun --nproc_per_node ${NUM_GPU} pretrain_gpt.py ${OPTIONS} + nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi -o gpt_sft_hetero_cp_iter7_8_flash torchrun --nproc_per_node ${NUM_GPU} pretrain_gpt.py ${OPTIONS} else torchrun --nproc_per_node ${NUM_GPU} pretrain_gpt.py ${OPTIONS} fi diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 81002e357f0..ace4186cebc 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -842,6 +842,8 @@ def __init__( self.kept_packed_seq_params = set( field.name for field in dataclasses.fields(PackedSeqParams) ) + self.kept_packed_seq_params.discard("scheduled_id") + if get_te_version() < PkgVersion("1.3.0"): # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H # copies (#555) @@ -887,10 +889,11 @@ def forward( if packed_seq_params is not None: if packed_seq_params.cp_group is not None: self.cp_group = packed_seq_params.cp_group + # TODO(pmannan): Add cp_comm_type super().set_context_parallel_group( self.cp_group, torch.distributed.get_process_group_ranks(self.cp_group), - super().cp_stream, + TEDotProductAttention.cp_stream, ) elif packed_seq_params.local_cp_size is not None: super().set_context_parallel_group(None, None, None) diff --git a/megatron/core/packed_seq_params.py b/megatron/core/packed_seq_params.py index b93cc970ffb..ea5e9463187 100644 --- a/megatron/core/packed_seq_params.py +++ b/megatron/core/packed_seq_params.py @@ -20,7 +20,6 @@ class PackedSeqParams: cu_seqlens_kv_padded: Tensor = None max_seqlen_q: int = None max_seqlen_kv: int = None - cp_assignment: List[int] = None local_cp_size: int = None - schedule_id: int = None + scheduled_id: int = None cp_group: dist.ProcessGroup = None diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 526a873662d..adda3270ca5 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -5,6 +5,7 @@ import os import warnings from datetime import timedelta +from math import log2 from typing import Callable, List, Optional import numpy as np diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index e3bd8c89b09..3d63a43a875 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from functools import lru_cache from collections import deque -import math +from math import ceil, log2 import heapq class HybridCPWrapper(): @@ -43,8 +43,8 @@ def __next__(self) -> Any: Get the next item from the dataset, pull scheduling metadata and return it. """ sample = next(self.data_iterator) - assert "cu_lengths" in sample, "cu_lengths must be in the sample" - # TODO(milestone 2): Get cu_lengths and all-gather the entire global batch worth cu_lengths and then perform the scheduling. + assert "cu_seqlens" in sample, "cu_seqlens must be in the sample" + # TODO(milestone 2): Get cu_seqlens and all-gather the entire global batch worth cu_seqlens and then perform the scheduling. # But why should this scheduling information be integrated back into the data? groups, sample_id_groups = self.cp_balancing_scheduler.get_groups_and_subsamples(sample, self.config) sample["groups"] = groups @@ -75,105 +75,6 @@ def get_total_workload(self, seq_length: int, cp_size: Optional[int] = None): cp_size = self.gpus_needed(seq_length) return (seq_length * seq_length) / cp_size - def get_heterogeneous_cp_assignment( - self, - cu_seqlens: List[int], - max_seqlen_per_cp_rank: int, - cp_size: int, - compute_estimator: Optional[Callable] = None, - ): - """ - cu_seqlens: list of sub-sample sequence lengths - max_seqlen_per_cp_rank: list of max sequence length per CP rank - cp_size: total number of CP ranks - flops_calculator: function to calculate flops from cu_seqlens - - Returns: - start_time[j]: the time job j begins - assignment[j]: list of resource IDs assigned to job j - """ - if compute_estimator is None: - compute_estimator = self.get_total_workload - #TODO: Account for nvlink - IBlink boundaries. How to dynamically get this information? - #TODO: is cu_seqlens a list of ints or tensors? Correct the type hint - sub_sample_lens = cu_seqlens[0][1:] - cu_seqlens[0][:-1] - cp_rank = torch.distributed.get_rank() # Get rank from CP group - cp_size_per_sample = [math.ceil(x / max_seqlen_per_cp_rank) for x in sub_sample_lens] - total_workload_per_cp_rank = [ - compute_estimator(x, cp) for x, cp in zip(sub_sample_lens, cp_size_per_sample) - ] - # Sort workloads in descending order - num_sub_samples = len(sub_sample_lens) - jobs = sorted(range(num_sub_samples), key=lambda j: total_workload_per_cp_rank[j], reverse=True) - - # a min-heap of free resource IDs (CP rank IDs) - if len(self.free_resources) == 0: - self.free_resources = list(range(cp_size)) - heapq.heapify(self.free_resources) - - # events: (release_time, [list of resource IDs freeing then]) - events = [] - # Trackers used in scheduling algorithm - current_time = 0 - start_time = [None] * num_sub_samples - assignment = [None] * num_sub_samples - num_sub_samples_processed = 0 - - while jobs: - made_progress = True - # try to schedule any sub-sample that fits in the currently free resources - while made_progress: - made_progress = False - for j in list(jobs): - if cp_size_per_sample[j] <= len(self.free_resources): - # grab the lowest‐ID CP ranks available - assigned = [heapq.heappop(self.free_resources) for _ in range(cp_size_per_sample[j])] - if cp_rank in assigned: - num_sub_samples_processed += 1 - start_time[j] = current_time - assignment[j] = assigned - # schedule the completion of the sub-sample compute - release_time = current_time + total_workload_per_cp_rank[j] - heapq.heappush(events, (release_time, assigned)) - jobs.remove(j) - made_progress = True - break - - # if nothing fits right now, advance to the next release event - if not made_progress and events: - t, freed_ids = heapq.heappop(events) - current_time = t - for rid in freed_ids: - heapq.heappush(self.free_resources, rid) - elif not events: - # should not happen when cp_size ≥ max(cp_size_per_sample) - break - - return num_sub_samples_processed, assignment - - # def get_per_microbatch_assignment( - # self, - # cu_seqlens: List[int], - # cp_size: int, - # compute_estimator: Optional[Callable] = None, - # ): - # """ - # cu_seqlens: list of sub-sample sequence lengths - # cp_size: total number of CP ranks - # compute_estimator: function to calculate flops from cu_seqlens - # """ - # micro_batches = [] - # exec_times = [] - # if compute_estimator is None: - # compute_estimator = self.get_total_workload - # sub_sample_lens = cu_seqlens[0][1:] - cu_seqlens[0][:-1] - # while sub_sample_lens: - # sub_sample_lens = sorted(sub_sample_lens, reverse=True) - # microbatch, sub_sample_lens, exec_times = self.next_hdp_group(sub_sample_lens, compute_estimator, cp_size) - # micro_batches.append(microbatch) - # exec_times.append(exec_times) - # return micro_batches - @lru_cache(maxsize=128) def gpus_needed(self, seq_len: int) -> int: return max(1, 2 ** ceil(log2((seq_len / self.max_seq_len_per_rank)))) @@ -528,20 +429,19 @@ def fill_empty_gpus(micro_batches, exec_times, sample_ids_per_gpu, group_members def get_groups_and_subsamples( self, data, - model, config, ): # TODO: Protect for model parallelism # TODO: Reduce access to file system as much as possible. groups = [] sample_id_groups = [] - assert "cu_lengths" in data, ( + assert "cu_seqlens" in data, ( "data must have a cu_seqlens attribute to define the valid sequenece lengths " "of each sub-sample in a packed sample to use hybrid context parallel" ) # We assign a sample_id to each sub-sample in order to track the right assignment to each GPU. # TODO (Milestone 2): Sample ID logic will have to change once we have global batch - sample_id_seqlens = [(i, int(data["cu_lengths"][0][i+1] - data["cu_lengths"][0][i])) for i in range(0, data["cu_lengths"][0].shape[0] - 1)] + sample_id_seqlens = [(i, int(data["cu_seqlens"][0][i+1] - data["cu_seqlens"][0][i])) for i in range(0, data["cu_seqlens"][0].shape[0] - 1)] sample_id_seqlens = sorted(sample_id_seqlens, key=lambda x: x[1], reverse=True) while sample_id_seqlens: mb, sample_id_seqlens, exec_times, sample_ids = self.next_hdp_group(sample_id_seqlens, self.get_total_workload, config.context_parallel_size) @@ -611,8 +511,8 @@ def hybrid_context_parallel_forward_backward( partner_cp_size = len([True for sample_ids in sample_id_groups[j] if sub_sample_id in sample_ids]) if partner_cp_size == 0: assert False, f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} j: {j} k: {k} sample_ids_group: {sample_id_groups}" - data["local_cp_size"] = partner_cp_size - data["scheduled_id"] = sub_sample_id + data["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) + data["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) new_data_iterator = RerunDataIterator(iter([data])) # TODO: Change data iterator to the right sub-sample # TODO: Find the usage of current_microbatch and is_first_microbatch and how that may affect my usage. @@ -651,8 +551,8 @@ def hybrid_context_parallel_forward_backward( partner_cp_size = len([True for sample_ids in sample_id_groups[j] if sub_sample_id in sample_ids]) if partner_cp_size == 0: assert False, f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} j: {j} k: {k} sample_ids_group: {sample_id_groups}" - data["local_cp_size"] = partner_cp_size - data["scheduled_id"] = sub_sample_id + data["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) + data["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) # TODO: What else should I update in data so that we can get the right sub-sample? new_data_iterator = RerunDataIterator(iter([data])) # TODO: Change data iterator to the right sub-sample @@ -683,8 +583,8 @@ def hybrid_context_parallel_forward_backward( for k in range(len(sample_ids_per_group) - 1): sub_sample_id = sample_ids_per_group[k] partner_cp_size = len([True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids]) - data["local_cp_size"] = partner_cp_size - data["scheduled_id"] = sub_sample_id + data["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) + data["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) # TODO: What else should I update in data so that we can get the right sub-sample? new_data_iterator = RerunDataIterator(iter([data])) # TODO: Change data iterator to the right sub-sample @@ -711,8 +611,8 @@ def hybrid_context_parallel_forward_backward( partner_cp_size = len([True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids]) if partner_cp_size == 0: assert False, f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} j: {j} k: {k} sample_ids_group: {sample_id_groups}" - data["local_cp_size"] = partner_cp_size - data["scheduled_id"] = sub_sample_id + data["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) + data["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) # TODO: What else should I update in data so that we can get the right sub-sample? new_data_iterator = RerunDataIterator(iter([data])) # TODO: Change data iterator to the right sub-sample diff --git a/megatron/core/utils.py b/megatron/core/utils.py index d2e0cd85560..3845dd837e0 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -29,6 +29,7 @@ from megatron.core import config from megatron.core.package_info import __version__ as mcore_version +from megatron.core.packed_seq_params import PackedSeqParams try: from torch.distributed._tensor import DTensor @@ -1886,6 +1887,8 @@ def get_sub_sample_on_this_cp_rank(batch, scheduled_id, local_cp_size, packed_se end_index = cu_lengths[scheduled_id+1] # TODO (flexible HCP): New CP size also means new padding requirement. CP4 to CP3 changes padding requirement. for key, data in batch.items(): + if key in {'attention_mask', 'cu_seqlens', 'max_seqlen', 'scheduled_id', 'local_cp_size'}: + continue batch[key] = data[:, start_index:end_index] # TODO (milestone 2): Enable this when we do DPxCP diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 0b89ad26664..3a168bb7ea1 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -471,6 +471,16 @@ def _broadcast(item): if "max_seqlen" not in data else data["max_seqlen"].cuda(non_blocking=True) ), + 'scheduled_id': ( + None + if "scheduled_id" not in data + else data["scheduled_id"].cuda(non_blocking=True) + ), + 'local_cp_size': ( + None + if "local_cp_size" not in data + else data["local_cp_size"].cuda(non_blocking=True) + ), } def _broadcast_cu_seqlens(cu_seqlens): @@ -498,6 +508,8 @@ def _broadcast_cu_seqlens(cu_seqlens): _broadcast(batch['position_ids']) _broadcast_cu_seqlens(batch['cu_seqlens']) _broadcast(batch['max_seqlen']) + _broadcast(batch['scheduled_id']) + _broadcast(batch['local_cp_size']) elif mpu.is_pipeline_first_stage(): _broadcast(batch['tokens']) @@ -551,14 +563,21 @@ def _broadcast_cu_seqlens(cu_seqlens): ) cu_seqlens = None - if args.sft: - max_seqlen = torch.empty( - 1, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) - else: - max_seqlen = None + max_seqlen = torch.empty( + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + scheduled_id = torch.empty( + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + local_cp_size = torch.empty( + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) def _broadcast_cu_seqlens(): dev = torch.cuda.current_device() @@ -583,6 +602,8 @@ def _broadcast_cu_seqlens(): _broadcast(position_ids) cu_seqlens = _broadcast_cu_seqlens() _broadcast(max_seqlen) + _broadcast(scheduled_id) + _broadcast(local_cp_size) elif mpu.is_pipeline_first_stage(): labels = None @@ -621,6 +642,8 @@ def _broadcast_cu_seqlens(): 'position_ids': position_ids, 'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, + 'scheduled_id': scheduled_id, + 'local_cp_size': local_cp_size, } return batch diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 7d9b941278e..64f7f392eda 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -36,7 +36,7 @@ ) from megatron.core.rerun_state_machine import get_rerun_state_machine from megatron.core.transformer.spec_utils import import_module -from megatron.core.utils import is_te_min_version, StragglerDetector +from megatron.core.utils import is_te_min_version, StragglerDetector, get_sub_sample_on_this_cp_rank from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0 from megatron.training.arguments import core_transformer_config_from_args from megatron.training.utils import ( @@ -228,21 +228,41 @@ def get_batch(data_iterator): # slice batch along sequence dimension for context parallelism # batch = get_batch_on_this_cp_rank(batch) - cu_seqlens = batch['cu_seqlens'] + cu_seqlens = batch.pop('cu_seqlens') + max_seqlen = batch.pop('max_seqlen') + scheduled_id = batch.pop('scheduled_id') + local_cp_size = batch.pop('local_cp_size') + if scheduled_id is not None: + scheduled_id = int(scheduled_id.item()) + local_cp_size = int(local_cp_size.item()) + + if cu_seqlens is not None: + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens[0], + cu_seqlens_kv=cu_seqlens[0], + cu_seqlens_q_padded=cu_seqlens[0], + cu_seqlens_kv_padded=cu_seqlens[0], + max_seqlen_q=int(max_seqlen[0].item()), + max_seqlen_kv=int(max_seqlen[0].item()), + ) + else: + packed_seq_params = None + if cu_seqlens is None: # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) # The implementation of this function is in MCore - else: # Packed THD format - assert ( - cu_seqlens.dim() == 2 and cu_seqlens.shape[0] == 1 - ), "micro-batch-size must be 1 for packing" + elif local_cp_size is None: # Packed THD format + # assert ( + # cu_seqlens.dim() == 2 and cu_seqlens.shape[0] == 1 + # ), "micro-batch-size must be 1 for packing" cu_seqlens = cu_seqlens[0] - batch['cu_seqlens'] = cu_seqlens + # batch['cu_seqlens'] = cu_seqlens - max_seqlen = batch['max_seqlen'] + # max_seqlen = batch['max_seqlen'] assert max_seqlen.dim() == 1 # TODO(duncan): can this be kept as a 0-D tensor? - batch['max_seqlen'] = int(max_seqlen[0].item()) + # batch['max_seqlen'] = int(max_seqlen[0].item()) cp_size = get_context_parallel_world_size() if cp_size > 1: # slice batch along sequence dimension for context parallelism @@ -261,8 +281,20 @@ def get_batch(data_iterator): if key in {'attention_mask', 'cu_seqlens', 'max_seqlen'}: continue batch[key] = data.index_select(1, index) + else: # Hybrid CP format + assert local_cp_size is not None + assert scheduled_id is not None + batch, cp_group, packed_seq_params = get_sub_sample_on_this_cp_rank(batch, scheduled_id, local_cp_size, packed_seq_params) + packed_seq_params.cp_group = cp_group + if cp_group is not None and cp_group.size() > 1: + # print(f"rank: {torch.distributed.get_rank()} has partner ranks {torch.distributed.get_process_group_ranks(cp_group)} and got keys: {batch.keys()}") + batch = get_batch_on_this_cp_rank(batch, cp_group.size(), torch.distributed.get_rank(group=cp_group)) + # print(f"rank: {torch.distributed.get_rank()} has partner ranks {torch.distributed.get_process_group_ranks(cp_group)} and got tokens: {batch['tokens'].shape}") + # else: + # print(f"rank: {torch.distributed.get_rank()} has no partner ranks and got tokens: {batch['tokens'].shape}") - return batch.values() + + return (*batch.values(), packed_seq_params) # define spiky loss as a loss that's 10x the max loss observed @@ -346,19 +378,7 @@ def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = Fa timers('batch-generator', log_level=2).start() global stimer with stimer(bdata=True): - tokens, labels, loss_mask, attention_mask, position_ids, cu_seqlens, max_seqlen = get_batch(data_iterator) - if cu_seqlens is None: - packed_seq_params = None - else: - packed_seq_params = PackedSeqParams( - qkv_format="thd", - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - cu_seqlens_q_padded=None, - cu_seqlens_kv_padded=None, - max_seqlen_q=max_seqlen, - max_seqlen_kv=max_seqlen, - ) + tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = get_batch(data_iterator) timers('batch-generator').stop() From 286129cbec8e46b1e32586efb01e003f2d5a14d3 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Fri, 22 Aug 2025 22:47:09 -0700 Subject: [PATCH 17/81] Remove new SFT loss calculation in favor of regular loss calc --- megatron/training/training.py | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/megatron/training/training.py b/megatron/training/training.py index de3d44d14f9..6989f1d2f72 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1489,28 +1489,14 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch for key in losses_reduced[0].keys(): val = [x[key].view(-1) for x in losses_reduced] if val[0].numel() == 2: - if args.sft: - # in mcore the normalization happens on micro batch instead of global - val = torch.vstack(val) - val = val[:, 0] / val[:, 1] - val = val.mean() - torch.distributed.all_reduce( - val, - group=mpu.get_data_parallel_group(with_context_parallel=True) - ) - val /= torch.distributed.get_world_size( - group=mpu.get_data_parallel_group(with_context_parallel=True) - ) - loss_reduced[key] = val - else: - # there is one dict per microbatch. in new reporting, we average - # over the total number of tokens across the global batch. - val = torch.vstack(val).sum(dim=0) - torch.distributed.all_reduce( - val, - group=mpu.get_data_parallel_group(with_context_parallel=True) - ) - loss_reduced[key] = val[0] / val[1] + # there is one dict per microbatch. in new reporting, we average + # over the total number of tokens across the global batch. + val = torch.vstack(val).sum(dim=0) + torch.distributed.all_reduce( + val, + group=mpu.get_data_parallel_group(with_context_parallel=True) + ) + loss_reduced[key] = val[0] / val[1] elif val[0].numel() == 1: # legacy behavior, we average over the number of microbatches val = torch.cat(val).mean() From 2674be82487d6f0a19eafa4ab776c4815f22af15 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Fri, 29 Aug 2025 15:19:35 -0700 Subject: [PATCH 18/81] Milestone 2: Add global batch scheduling, data movement and training --- megatron/core/datasets/gpt_dataset.py | 3 + megatron/core/parallel_state.py | 40 +- .../pipeline_parallel/hybrid_cp_schedule.py | 361 ++++++++++++------ megatron/legacy/data/data_samplers.py | 79 +++- megatron/training/datasets/sft_dataset.py | 3 +- megatron/training/training.py | 4 +- pretrain_gpt.py | 42 +- 7 files changed, 381 insertions(+), 151 deletions(-) diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index 37d984c2a36..e14bb966c40 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -55,6 +55,9 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig): context_parallel_size: int = 1 """Option to enable context parallelism""" + data_parallel_size: int = 1 + """Option to enable data parallelism""" + def __post_init__(self) -> None: """Do asserts and set fields post init""" super().__post_init__() diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index adda3270ca5..194118a1306 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -107,7 +107,7 @@ # Hierarchical context parallel groups _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS = None # Hybrid context parallel groups -_HYBRID_CP_GROUPS = {} +_HYBRID_DP_CP_GROUPS = {} # Data parallel group information with context parallel combined. _DATA_PARALLEL_GROUP_WITH_CP = None @@ -366,8 +366,8 @@ def create_hierarchical_groups( assert rank not in ranks or len(hierarchical_groups_gloo) == len(hierarchical_group_sizes) return hierarchical_groups, hierarchical_groups_gloo -def create_hybrid_cp_groups(rank, ranks, pg_options): - hybrid_cp_groups = {} +def create_hybrid_dp_cp_groups(rank, ranks, pg_options): + hybrid_dp_cp_groups = {} # Generate group for every power of 2 up to the number of CP ranks # We limit the allowed group sizes in order to avoid excessive overhead. group_sizes = [2**i for i in range(int(log2(len(ranks))))][1:] @@ -376,12 +376,12 @@ def create_hybrid_cp_groups(rank, ranks, pg_options): group = create_group( ranks[i:i+group_size], pg_options=pg_options, - group_desc=f"HYBRID_CP_GROUP_{group_size}", + group_desc=f"HYBRID_DP_CP_GROUP_{group_size}", ) if rank in ranks[i:i+group_size]: - assert group_size not in hybrid_cp_groups, f"Rank {rank} appears in multiple Hybrid CP groups of size {group_size}" - hybrid_cp_groups[group_size] = group - return hybrid_cp_groups + assert group_size not in hybrid_dp_cp_groups, f"Rank {rank} appears in multiple Hybrid DP CP groups of size {group_size}" + hybrid_dp_cp_groups[group_size] = group + return hybrid_dp_cp_groups class RankGenerator(object): """A class for generating rank groups for different modes of parallelism.""" @@ -819,6 +819,13 @@ def initialize_model_parallel( # Set `NCCL_COLLNET_ENABLE=0` to restrict SHARP application to DP process groups os.environ["NCCL_COLLNET_ENABLE"] = "0" + if hybrid_context_parallel: + assert len(ranks_with_cp) % 2 == 0, "Hybrid context parallel requires an even number of ranks" + global _HYBRID_DP_CP_GROUPS + if rank in ranks_with_cp: + _HYBRID_DP_CP_GROUPS.update(create_hybrid_dp_cp_groups(rank, ranks_with_cp, get_nccl_options("dp_cp", nccl_comm_cfgs))) + #TODO: Are gloo groups needed for hybrid cp? + for ranks in decoder_rank_generator.get_ranks('dp'): group = create_group( ranks, @@ -866,13 +873,6 @@ def initialize_model_parallel( if rank in ranks: _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS = hierarchical_groups - if hybrid_context_parallel: - assert len(ranks) % 2 == 0, "Hybrid context parallel requires an even number of ranks" - global _HYBRID_CP_GROUPS - if rank in ranks: - _HYBRID_CP_GROUPS.update(create_hybrid_cp_groups(rank, ranks, get_nccl_options("cp", nccl_comm_cfgs))) - #TODO: Are gloo groups needed for hybrid cp? - # Build the model-parallel groups. global _MODEL_PARALLEL_GROUP global _MODEL_PARALLEL_GLOBAL_RANKS @@ -1322,16 +1322,16 @@ def get_hierarchical_context_parallel_groups(check_initialized=True): assert _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS is not None return _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS -def get_hybrid_context_parallel_groups(check_initialized=True, group_size=None): +def get_hybrid_data_context_parallel_groups(check_initialized=True, group_size=None): """Get the hybrid context parallel groups the caller rank belongs to.""" # If the group size is the same as the entire DPxCP group, return the original group - if get_context_parallel_world_size() == group_size: + if get_data_parallel_world_size(with_context_parallel=True) == group_size: if check_initialized: - assert _CONTEXT_PARALLEL_GROUP is not None - return _CONTEXT_PARALLEL_GROUP + assert _DATA_PARALLEL_GROUP_WITH_CP is not None + return _DATA_PARALLEL_GROUP_WITH_CP if check_initialized: - assert _HYBRID_CP_GROUPS is not None - return _HYBRID_CP_GROUPS[group_size] + assert _HYBRID_DP_CP_GROUPS is not None + return _HYBRID_DP_CP_GROUPS[group_size] def get_embedding_group(check_initialized=True): """Get the embedding group the caller rank belongs to.""" diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index 3d63a43a875..e01ac0d9f2f 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -8,20 +8,50 @@ from math import ceil, log2 import heapq -class HybridCPWrapper(): +class HybridCPDatasetWrapper(): """ - A wrapper class that wraps around any existing dataset and prints samples - as they are requested from the original dataset. + A wrapper class that wraps around any existing dataset. + It adds batch_idx to the sample and returns the sample. - This wrapper implements the standard PyTorch Dataset interface and can be - used with any dataset that follows the same interface. + This is required if certains DP ranks need the dataset idx + to access the appropriate sub-sample assigned to them from + the file system even if it was not originally assigned to them. Args: dataset: The original dataset to wrap around - print_format: Format string for printing samples (default: "Sample {idx}: {sample}") - print_func: Function to use for printing (default: print) - max_print_length: Maximum length of sample to print (default: 200 chars) - print_every: Print every Nth sample (default: 1, print all) + """ + + def __init__( + self, + dataset, + ): + self.dataset = dataset + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx) -> Any: + """ + Get the next item from the dataset, add batch_idx and return it. + """ + sample = self.dataset[idx] + sample["batch_idx"] = idx + assert "cu_seqlens" in sample, "cu_seqlens must be in the sample to use hybrid context parallel" + return sample + +class HybridCPDataLoaderWrapper(): + """ + A wrapper class that wraps around an existing data_iterator. + For every __next__ call, + 1. Each DP rank pulls a batch of packed samples. + 2. Extracts the sequence lengths of each sub-sample and all-gathers across the DP group. + 3. Schedules the sub-samples to the DPxCP ranks using the BalancedCPScheduler. + 4. Based on the schedule, reroutes the sub-samples to the correct rank using all-to-all. + 5. Returns the assigned sub-samples to this rank. + + Args: + data_iterator: The original data_iterator to wrap around + config: The config object containing the max_seqlen_per_cp_rank """ def __init__( @@ -33,23 +63,172 @@ def __init__( self.sample_count = 0 self.config = config self.cp_balancing_scheduler = BalancedCPScheduler(max_seq_len_per_rank=self.config.max_seqlen_per_cp_rank) + self.total_hdp_gpus = parallel_state.get_data_parallel_world_size(with_context_parallel=True) def __iter__(self): """Return self as an iterator.""" return self - + + def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: + # Collect the number of subsamples from all ranks + local_len = torch.tensor([subsample_seqlens.shape[0]], dtype=torch.int32).cuda() + dp_subsample_count = [torch.zeros_like(local_len) for _ in range(parallel_state.get_data_parallel_world_size())] + torch.distributed.all_gather(dp_subsample_count, local_len, group=parallel_state.get_data_parallel_group()) + + # Find the maximum number of subsamples across all ranks and pad the subsample_seqlens to the max length + dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1) + max_sub_samples = int(dp_subsample_counts.max().item()) + + if local_len.item() < max_sub_samples: + subsample_seqlens_padded = torch.cat([subsample_seqlens, torch.zeros(max_sub_samples - local_len.item(), dtype=torch.int32).cuda()], dim=0) + else: + subsample_seqlens_padded = subsample_seqlens + + # Gather the subsample_seqlens from all ranks + seqlens_gathered = [torch.empty_like(subsample_seqlens_padded) for _ in range(parallel_state.get_data_parallel_world_size())] + torch.distributed.all_gather(seqlens_gathered, subsample_seqlens_padded, group=parallel_state.get_data_parallel_group()) + + # Trim each seqlens_gathered to the length of the correct sample + for dp_rank, seqlen in enumerate(seqlens_gathered): + seqlens_gathered[dp_rank] = seqlen[:dp_subsample_counts[dp_rank]] + + seqlens_gathered = torch.cat(seqlens_gathered, dim=0) + seqlens_gathered = seqlens_gathered.cpu().tolist() + + # Calculate the offsets to assign unique global ID to each subsample. + csum = torch.cumsum(dp_subsample_counts, dim=0, dtype=torch.int32) + offsets = torch.cat([torch.zeros(1, dtype=torch.int32), csum[:-1]], dim=0) + + return seqlens_gathered, offsets + + def get_global_id_seqlens(self, num_local_subsamples, offsets, seqlens_gathered): + # Calculate the global ID for each subsample + dp_rank = parallel_state.get_data_parallel_rank() + global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda() + # Create a list of (global_id, seqlen) tuples for scheduling + global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] + # Get the global IDs locally present on this rank + global_ids_this_rank = global_ids[offsets[dp_rank]:offsets[dp_rank] + num_local_subsamples] + + return global_id_seqlens, global_ids_this_rank + + def _gid_to_src_rank(self, gid: int, offsets: List[int]) -> int: + dp_src_rank = torch.bucketize(gid, offsets[1:]-1) + hdp_rank = torch.distributed.get_process_group_ranks(parallel_state.get_data_parallel_group())[dp_src_rank] + return hdp_rank + + def reroute_samples_to_hdp_ranks(self, batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets): + gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} + hdp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) + dp_ranks = torch.distributed.get_process_group_ranks(parallel_state.get_data_parallel_group()) + data_keys = batch[0].keys() + + # Create the send plan + combined_sample_id_groups: List[List[int]] = [[] for _ in range(self.total_hdp_gpus)] + + for d in range(self.total_hdp_gpus): + for sample_id_group in sample_id_groups: + combined_sample_id_groups[d].extend(sample_id_group[d]) + + for dest_rank in range(self.total_hdp_gpus): + combined_sample_id_groups[dest_rank].sort() + + # Filter out samples that are not present on this rank + send_ids_sorted = [gid for d in dp_ranks for gid in combined_sample_id_groups[d] if gid in global_ids_this_rank] + # send_counts = [len(combined_sample_id_groups[d]) for d in range(self.total_hdp_gpus)] + + send_lens_split = [0] * self.total_hdp_gpus + for dest_rank in range(self.total_hdp_gpus): + if dest_rank in dp_ranks: + send_lens_split[dest_rank] = sum([global_id_seqlens[gid][1] for gid in combined_sample_id_groups[dest_rank] if gid in global_ids_this_rank]) + else: + # We only need to share local data with DP ranks that have different data. + send_lens_split[dest_rank] = 0 + + # Create the recv plan + recv_sample_id_groups = [[] for _ in range(self.total_hdp_gpus)] + for gid in combined_sample_id_groups[hdp_rank]: + src_rank = self._gid_to_src_rank(gid, offsets) + recv_sample_id_groups[src_rank].append(gid) + + recv_lens_split = [0] * self.total_hdp_gpus + for src_rank in range(self.total_hdp_gpus): + recv_lens_split[src_rank] = sum([global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]]) + + # print(f"rank: {torch.distributed.get_rank()}, recv_lens_split: {recv_lens_split} send_lens_split: {send_lens_split}") + + recv_ids_sorted = [gid for d in range(self.total_hdp_gpus) for gid in recv_sample_id_groups[d]] + recv_counts = [len(recv_sample_id_groups[d]) for d in range(self.total_hdp_gpus)] + + recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))] + + def _pack_sample_by_key(key: str) -> torch.Tensor: + flattened_tensors = [] + for gid in send_ids_sorted: + t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True) + flattened_tensors.append(t) + return torch.cat(flattened_tensors, dim=0) if flattened_tensors else torch.empty(0, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) + + def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): + cursor = 0 + for i, gid in enumerate(recv_ids_sorted): + sample_len = global_id_seqlens[gid][1] + recv_samples[i][key] = recv_tensor[cursor:cursor+sample_len] + cursor += sample_len + + for key in data_keys: + send_tensor = _pack_sample_by_key(key) + recv_tensor = torch.empty(sum(recv_lens_split), device=torch.cuda.current_device(), dtype=send_tensor.dtype) + torch.distributed.all_to_all_single( + output=recv_tensor, + input=send_tensor, + output_split_sizes=recv_lens_split, + input_split_sizes=send_lens_split, + group=parallel_state.get_data_parallel_group(with_context_parallel=True), + ) + _unpack_sample_by_key(key, recv_tensor) + + recv_sample_with_id = {recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted)} + return recv_sample_with_id + + def unpack_batch(self, batch): + batch_unpacked = [] + for sample in batch: + for sub_sample in range(sample["cu_seqlens"].shape[0] - 1): + sub_sample_dict = {} + start_idx = sample["cu_seqlens"][sub_sample] + end_idx = sample["cu_seqlens"][sub_sample + 1] + for key in sample.keys(): + if key in ["cu_seqlens", "batch_idx", "max_seqlen"]: + continue + sub_sample_dict[key] = sample[key][start_idx:end_idx] + batch_unpacked.append(sub_sample_dict) + return batch_unpacked + def __next__(self) -> Any: """ Get the next item from the dataset, pull scheduling metadata and return it. """ - sample = next(self.data_iterator) - assert "cu_seqlens" in sample, "cu_seqlens must be in the sample" - # TODO(milestone 2): Get cu_seqlens and all-gather the entire global batch worth cu_seqlens and then perform the scheduling. - # But why should this scheduling information be integrated back into the data? - groups, sample_id_groups = self.cp_balancing_scheduler.get_groups_and_subsamples(sample, self.config) - sample["groups"] = groups - sample["sample_id_groups"] = sample_id_groups - return sample + batch = next(self.data_iterator) + subsample_seqlens = [] + for sample in batch: + subsample_seqlens.extend([int(sample["cu_seqlens"][i+1] - sample["cu_seqlens"][i]) for i in range(0, sample["cu_seqlens"].shape[0] - 1)]) + subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda() + + seqlens_gathered, offsets = self.get_global_seqlens(subsample_seqlens) + + global_id_seqlens, global_ids_this_rank = self.get_global_id_seqlens(subsample_seqlens.shape[0], offsets, seqlens_gathered) + # global_id_seqlens = sorted(global_id_seqlens, key=lambda x: x[1], reverse=True) + + groups, sample_id_groups = self.cp_balancing_scheduler.get_groups_and_subsamples(global_id_seqlens, self.config) + # sample["groups"] = groups + # sample["sample_id_groups"] = sample_id_groups + + batch = self.unpack_batch(batch) + samples_this_rank_with_id = self.reroute_samples_to_hdp_ranks(batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets) + # for sample_id, sample in samples_this_rank_with_id.items(): + # sample["sample_id_groups"] + return samples_this_rank_with_id, sample_id_groups class BalancedCPScheduler: def __init__(self, max_seq_len_per_rank: int): @@ -57,6 +236,7 @@ def __init__(self, max_seq_len_per_rank: int): self.num_subsamples = 0 self.num_subsamples_processed = 0 self.free_resources = [] + self.total_hdp_gpus = parallel_state.get_data_parallel_world_size(with_context_parallel=True) @lru_cache(maxsize=128) def get_total_workload(self, seq_length: int, cp_size: Optional[int] = None): @@ -428,29 +608,21 @@ def fill_empty_gpus(micro_batches, exec_times, sample_ids_per_gpu, group_members def get_groups_and_subsamples( self, - data, + sample_id_seqlens, config, ): # TODO: Protect for model parallelism # TODO: Reduce access to file system as much as possible. groups = [] sample_id_groups = [] - assert "cu_seqlens" in data, ( - "data must have a cu_seqlens attribute to define the valid sequenece lengths " - "of each sub-sample in a packed sample to use hybrid context parallel" - ) # We assign a sample_id to each sub-sample in order to track the right assignment to each GPU. - # TODO (Milestone 2): Sample ID logic will have to change once we have global batch - sample_id_seqlens = [(i, int(data["cu_seqlens"][0][i+1] - data["cu_seqlens"][0][i])) for i in range(0, data["cu_seqlens"][0].shape[0] - 1)] sample_id_seqlens = sorted(sample_id_seqlens, key=lambda x: x[1], reverse=True) while sample_id_seqlens: - mb, sample_id_seqlens, exec_times, sample_ids = self.next_hdp_group(sample_id_seqlens, self.get_total_workload, config.context_parallel_size) + mb, sample_id_seqlens, exec_times, sample_ids = self.next_hdp_group(sample_id_seqlens, self.get_total_workload, self.total_hdp_gpus) groups.append(mb) - if len(sample_ids) < config.context_parallel_size: - sample_ids.extend([] * (config.context_parallel_size - len(sample_ids))) + if len(sample_ids) < self.total_hdp_gpus: + sample_ids.extend([] * (self.total_hdp_gpus - len(sample_ids))) sample_id_groups.append(sample_ids) - # print(f"groups: {groups}") - # print(f"sample_id_groups: {sample_id_groups}") return groups, sample_id_groups @@ -495,67 +667,30 @@ def hybrid_context_parallel_forward_backward( from .schedules import forward_step, backward_step cp_balancing_scheduler = BalancedCPScheduler(max_seq_len_per_rank=config.max_seqlen_per_cp_rank) + # We get data once per global batch and schedule the sub-samples. + # TODO(pmannan): We will change from 3 for loop to 2 for loops. + # groups and sample_id_groups will be provided by the scheduler. + # sample_id_groups will now directly be a list of sub-samples per group and we iterate over it. + # data will be a list of sub-samples with sample_id. + # TODO(pmannan): Should we wrap the data_iterator here instead of the training.py file? + data = next(data_iterator) + sample_id_groups = data[1] + batch = data[0] + # TODO: How does this variable affect downstream logic? + num_microbatches = 1 + + # Upto last group, we don't need any sync. with no_sync_func(): - for i in range(num_microbatches - 1): - data = next(data_iterator) - # groups, sample_id_groups = cp_balancing_scheduler.get_groups_and_subsamples(data, model, config) - groups = data["groups"] - sample_id_groups = data["sample_id_groups"] - for j in range(len(groups)): - # Get sub-samples for the current CP rank - # TODO: Update to DPxCP rank when milestone 2 - sample_ids_per_group = sample_id_groups[j][parallel_state.get_context_parallel_rank()] - for k in range(len(sample_ids_per_group)): - # Call forward step for each sub-sample - sub_sample_id = sample_ids_per_group[k] - partner_cp_size = len([True for sample_ids in sample_id_groups[j] if sub_sample_id in sample_ids]) - if partner_cp_size == 0: - assert False, f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} j: {j} k: {k} sample_ids_group: {sample_id_groups}" - data["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) - data["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) - new_data_iterator = RerunDataIterator(iter([data])) - # TODO: Change data iterator to the right sub-sample - # TODO: Find the usage of current_microbatch and is_first_microbatch and how that may affect my usage. - output_tensor, num_tokens = forward_step( - forward_step_func, - new_data_iterator, - model, - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data, - is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0), - current_microbatch=i, - ) - total_num_tokens += num_tokens.item() - if not forward_only: - backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - - # TODO: Move to DPxCP barrier - torch.distributed.barrier(parallel_state.get_context_parallel_group()) - - # Last microbatch - # TODO: Call scheduler here. - with no_sync_func(): - data = next(data_iterator) - # groups, sample_id_groups = cp_balancing_scheduler.get_groups_and_subsamples(data, model, config) - groups = data["groups"] - sample_id_groups = data["sample_id_groups"] - for j in range(len(groups) - 1): - sample_ids_per_group = sample_id_groups[j][parallel_state.get_context_parallel_rank()] - for k in range(len(sample_ids_per_group)): + for j in range(len(sample_id_groups) - 1): + sample_ids_this_group = sample_id_groups[j][parallel_state.get_data_parallel_rank(with_context_parallel=True)] + for sub_sample_id in sample_ids_this_group: # Call forward step for each sub-sample - sub_sample_id = sample_ids_per_group[k] - + sample = batch[sub_sample_id] partner_cp_size = len([True for sample_ids in sample_id_groups[j] if sub_sample_id in sample_ids]) - if partner_cp_size == 0: - assert False, f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} j: {j} k: {k} sample_ids_group: {sample_id_groups}" - data["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) - data["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) - # TODO: What else should I update in data so that we can get the right sub-sample? - new_data_iterator = RerunDataIterator(iter([data])) - # TODO: Change data iterator to the right sub-sample + assert partner_cp_size > 0, f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} sample_ids_this_group: {sample_id_groups[j]}" + sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) + sample["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) + new_data_iterator = RerunDataIterator(iter([sample])) # TODO: Find the usage of current_microbatch and is_first_microbatch and how that may affect my usage. output_tensor, num_tokens = forward_step( forward_step_func, @@ -573,22 +708,22 @@ def hybrid_context_parallel_forward_backward( if not forward_only: backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - # TODO: Move to DPxCP barrier - torch.distributed.barrier(parallel_state.get_context_parallel_group()) + # Create a barrier at end of each group. + # This barrier ensures that all ranks are prepared to change assigned CP group sizes and + # no rank is starting a sub-sample ahead of it's partner ranks. + torch.distributed.barrier(parallel_state.get_data_parallel_group(with_context_parallel=True)) # For the last group, we need to run the last sub-sample out of the context handler. - # TODO: Find num sub-samples per group in this group + # TODO: Find num sub-samples per group in this group? with no_sync_func(): - sample_ids_per_group = sample_id_groups[-1][parallel_state.get_context_parallel_rank()] - for k in range(len(sample_ids_per_group) - 1): - sub_sample_id = sample_ids_per_group[k] + sample_ids_this_group = sample_id_groups[-1][parallel_state.get_data_parallel_rank(with_context_parallel=True)] + for k in range(len(sample_ids_this_group) - 1): + sub_sample_id = sample_ids_this_group[k] + sample = batch[sub_sample_id] partner_cp_size = len([True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids]) - data["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) - data["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) - # TODO: What else should I update in data so that we can get the right sub-sample? - new_data_iterator = RerunDataIterator(iter([data])) - # TODO: Change data iterator to the right sub-sample - # TODO: Find the usage of current_microbatch and is_first_microbatch and how that may affect my usage. + sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) + sample["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) + new_data_iterator = RerunDataIterator(iter([sample])) # Call forward step for each sub-sample output_tensor, num_tokens = forward_step( forward_step_func, @@ -607,16 +742,13 @@ def hybrid_context_parallel_forward_backward( backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) # The last sub-sample of the last group of the last microbatch is run out of the context handler. - sub_sample_id = sample_ids_per_group[-1] + sub_sample_id = sample_ids_this_group[-1] partner_cp_size = len([True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids]) - if partner_cp_size == 0: - assert False, f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} j: {j} k: {k} sample_ids_group: {sample_id_groups}" - data["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) - data["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) - # TODO: What else should I update in data so that we can get the right sub-sample? - new_data_iterator = RerunDataIterator(iter([data])) - # TODO: Change data iterator to the right sub-sample - # TODO: Find the usage of current_microbatch and is_first_microbatch and how that may affect my usage. + assert partner_cp_size > 0, f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} sample_ids_this_group: {sample_id_groups[-1]}" + sample = batch[sub_sample_id] + sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) + sample["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) + new_data_iterator = RerunDataIterator(iter([sample])) # Call forward step for each sub-sample output_tensor, num_tokens = forward_step( forward_step_func, @@ -634,10 +766,11 @@ def hybrid_context_parallel_forward_backward( if not forward_only: backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - # TODO: Move to DPxCP barrier - torch.distributed.barrier(parallel_state.get_context_parallel_group()) + torch.distributed.barrier(parallel_state.get_data_parallel_group(with_context_parallel=True)) - # TODO: Before returning forward_data_store, we need to change the loss. - # Instead of letting reporting loss be calculated by train_step, can we just calculate it here? - # Since we will need the global number of samples information to average the loss and local cp group size for each sample. + # TODO: Before returning forward_data_store, do we need to change the loss? + # If loss calculation is done as sum(loss_per_token) / sum(total_tokens_per_sample), + # we don't need to change the loss. + # But if the loss calculation is different, then the user needs to define a new loss function + # for hybrid context parallel in their training script. return forward_data_store, total_num_tokens \ No newline at end of file diff --git a/megatron/legacy/data/data_samplers.py b/megatron/legacy/data/data_samplers.py index 78c7e1af410..b73cccb98d3 100644 --- a/megatron/legacy/data/data_samplers.py +++ b/megatron/legacy/data/data_samplers.py @@ -10,6 +10,22 @@ from megatron.training import get_args from megatron.core import mpu +# class HybridCPCollator: +# def __init__(self): +# pass + +# def __call__(self, batch: List[Dict[str, Any]]): +# # IF WE WANT TO COLLATE BEFORE SCHEDULER +# # We get a list of samples when pulling multiple microbatches at once for scheduling with Hybrid DPxCP. +# # This would require us to change the scheduler and the data loader logic to support [1, S] instead of [S]. +# # if isinstance(batch, list): +# # for idx, sample in enumerate(batch): +# # collated_sample = torch.utils.data.default_collate([sample]) +# # batch[idx] = collated_sample +# # else: +# # batch = torch.utils.data.default_collate(batch) + +# return batch def build_pretraining_data_loader(dataset, consumed_samples): """Build dataloader given an input dataset.""" @@ -20,12 +36,21 @@ def build_pretraining_data_loader(dataset, consumed_samples): # Megatron sampler if args.dataloader_type == 'single': - batch_sampler = MegatronPretrainingSampler( - total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=args.micro_batch_size, - data_parallel_rank=mpu.get_data_parallel_rank(), - data_parallel_size=mpu.get_data_parallel_world_size()) + if args.hybrid_context_parallel: + batch_sampler = HybridCPMegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=args.micro_batch_size, + global_batch_size=args.global_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size()) + else: + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=args.micro_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size()) elif args.dataloader_type == 'cyclic': batch_sampler = MegatronPretrainingRandomSampler( dataset, @@ -44,11 +69,16 @@ def build_pretraining_data_loader(dataset, consumed_samples): args.dataloader_type)) # Torch dataloader. + if args.hybrid_context_parallel: + extra_kwargs = {"collate_fn": lambda x: x,} + else: + extra_kwargs = {} return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True, persistent_workers=True if args.num_workers > 0 else False, + **extra_kwargs, ) class MegatronPretrainingSampler: @@ -99,6 +129,43 @@ def __iter__(self): start_idx, end_idx = self.get_start_end_idx() yield batch[start_idx:end_idx] +class HybridCPMegatronPretrainingSampler(MegatronPretrainingSampler): + + def __init__(self, total_samples, consumed_samples, micro_batch_size, global_batch_size, + data_parallel_rank, data_parallel_size, drop_last=True): + super().__init__(total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last) + self.global_batch_size = global_batch_size + self.data_parallel_size = data_parallel_size + self.num_micro_batches = self.global_batch_size // self.micro_batch_times_data_parallel_size + + def __len__(self): + return self.total_samples + + def get_start_end_idx_global_batch(self): + start_idx = [self.data_parallel_rank * self.micro_batch_size + i * self.micro_batch_size * self.data_parallel_size for i in range(self.num_micro_batches)] + end_idx = [start_idx[i] + self.micro_batch_size for i in range(self.num_micro_batches)] + return start_idx, end_idx + + def __iter__(self): + batch = [] + # Last batch will be dropped if drop_last is not set False + for idx in range(self.consumed_samples, self.total_samples): + batch.append(idx) + if len(batch) == self.micro_batch_times_data_parallel_size * self.num_micro_batches: + start_idx, end_idx = self.get_start_end_idx_global_batch() + global_batch_idx = [] + for i in range(self.num_micro_batches): + global_batch_idx.extend(batch[start_idx[i]:end_idx[i]]) + yield global_batch_idx + batch = [] + + # Check the last partial batch and see drop_last is set + if len(batch) > 0 and not self.drop_last: + start_idx, end_idx = self.get_start_end_idx_global_batch() + global_batch_idx = [] + for i in range(self.num_micro_batches): + global_batch_idx.extend(batch[start_idx[i]:end_idx[i]]) + yield global_batch_idx class RandomSeedDataset(Dataset): diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index 5f37a4c5b7e..64fc15a0b85 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -124,7 +124,8 @@ def extend_with_padding(tokens, targets, positions, pad_len): pack_positions.extend(range(len(tokens_list))) if self.config.context_parallel_size > 1: - pad_granularity = self.config.context_parallel_size * 2 + # TODO(pmannan): This is a hack to pad for Hybrid DPxCP. + pad_granularity = self.config.context_parallel_size * self.config.data_parallel_size * 2 mod_token_count = len(pack_tokens) % pad_granularity if mod_token_count != 0: pad_len = pad_granularity - mod_token_count diff --git a/megatron/training/training.py b/megatron/training/training.py index 6989f1d2f72..0498b7731d8 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -96,7 +96,7 @@ get_pp_rank_microbatches, get_schedule_table, ) -from megatron.core.pipeline_parallel.hybrid_cp_schedule import HybridCPWrapper +from megatron.core.pipeline_parallel.hybrid_cp_schedule import HybridCPDataLoaderWrapper from megatron.core.num_microbatches_calculator import ( destroy_num_microbatches_calculator, get_current_global_batch_size, @@ -2073,7 +2073,7 @@ def train( one_logger = get_one_logger() if args.hybrid_context_parallel: - train_data_iterator = HybridCPWrapper(train_data_iterator, config) + train_data_iterator = iter(HybridCPDataLoaderWrapper(train_data_iterator, config)) if args.run_workload_inspector_server: try: diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 64f7f392eda..3060976ecf7 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -21,6 +21,7 @@ get_context_parallel_rank, get_context_parallel_world_size, ) +from megatron.core.pipeline_parallel.hybrid_cp_schedule import HybridCPDatasetWrapper from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset from megatron.core.enums import ModelType @@ -249,7 +250,7 @@ def get_batch(data_iterator): else: packed_seq_params = None - if cu_seqlens is None: + if cu_seqlens is None and local_cp_size is None: # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) # The implementation of this function is in MCore elif local_cp_size is None: # Packed THD format @@ -284,15 +285,37 @@ def get_batch(data_iterator): else: # Hybrid CP format assert local_cp_size is not None assert scheduled_id is not None - batch, cp_group, packed_seq_params = get_sub_sample_on_this_cp_rank(batch, scheduled_id, local_cp_size, packed_seq_params) - packed_seq_params.cp_group = cp_group + # batch, cp_group, packed_seq_params = get_sub_sample_on_this_cp_rank(batch, scheduled_id, local_cp_size, packed_seq_params) + if local_cp_size > 1: + cp_group = parallel_state.get_hybrid_data_context_parallel_groups(group_size=local_cp_size) + else: + cp_group = None + + # Convert [seqlen] to [1, seqlen] similar to default collate_fn + for key, data in batch.items(): + if key in ['attention_mask']: + continue + batch[key] = torch.stack([data], 0) + sample_length = batch['tokens'].shape[1] + # Create packed_seq_params for SBHD format with cp group information. + # TODO(pmannan): Since entire PackedSeqParams is not needed, should we create a new dataclass with our information? + # We will need to update the logic in extensions/transformer_engine to support this. + # Piping through a new dataclass from training script might be adding extra overhead. + # Take ADLR recommendation on this. + packed_seq_params = PackedSeqParams( + qkv_format="sbhd", + cu_seqlens_q=torch.tensor([0, sample_length], device="cuda", pin_memory=True), + cu_seqlens_kv=torch.tensor([0, sample_length], device="cuda", pin_memory=True), + cu_seqlens_q_padded=torch.tensor([0, sample_length], device="cuda", pin_memory=True), + cu_seqlens_kv_padded=torch.tensor([0, sample_length], device="cuda", pin_memory=True), + max_seqlen_q=sample_length, + max_seqlen_kv=sample_length, + local_cp_size=local_cp_size, + cp_group=cp_group, + ) + if cp_group is not None and cp_group.size() > 1: - # print(f"rank: {torch.distributed.get_rank()} has partner ranks {torch.distributed.get_process_group_ranks(cp_group)} and got keys: {batch.keys()}") batch = get_batch_on_this_cp_rank(batch, cp_group.size(), torch.distributed.get_rank(group=cp_group)) - # print(f"rank: {torch.distributed.get_rank()} has partner ranks {torch.distributed.get_process_group_ranks(cp_group)} and got tokens: {batch['tokens'].shape}") - # else: - # print(f"rank: {torch.distributed.get_rank()} has no partner ranks and got tokens: {batch['tokens'].shape}") - return (*batch.values(), packed_seq_params) @@ -434,6 +457,7 @@ def core_gpt_dataset_config_from_args(args): object_storage_cache_path=args.object_storage_cache_path, mid_level_dataset_surplus=args.mid_level_dataset_surplus, context_parallel_size=args.context_parallel_size, + data_parallel_size=args.data_parallel_size, ) @@ -461,6 +485,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): dataset_type, train_val_test_num_samples, is_dataset_built_on_rank, config ).build() + # train_ds = HybridCPDatasetWrapper(train_ds) + print_rank_0("> finished creating GPT datasets ...") return train_ds, valid_ds, test_ds From 3b873855c1aa0733f56d8a6f7684370177a764ee Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Mon, 15 Sep 2025 16:03:45 -0700 Subject: [PATCH 19/81] Fix for 0 seq len in data --- examples/run_hybrid_cp.sh | 23 ++++++++++--------- .../pipeline_parallel/hybrid_cp_schedule.py | 3 +++ 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/examples/run_hybrid_cp.sh b/examples/run_hybrid_cp.sh index 7123da2496c..d792878b599 100755 --- a/examples/run_hybrid_cp.sh +++ b/examples/run_hybrid_cp.sh @@ -8,10 +8,10 @@ #SBATCH -t 00:30:00 #SBATCH --mem=0 #SBATCH --ntasks-per-node=8 -#SBATCH --nodes=4 +#SBATCH --nodes=1 #SBATCH --exclusive #SBATCH --gpus-per-node=8 -#SBATCH --job-name=video_sft_stage3_qwen_2p5_7b_radio_research_cp_0518_exp +#SBATCH --job-name=hetero_cp_global export NCCL_IB_SL=1 export TOKENIZERS_PARALLELISM="false" @@ -33,7 +33,7 @@ USE_TE_CE=0 USE_FLASH_ATTN=0 USE_FSDP=0 USE_CUSTOM_FSDP=0 -PROFILE=1 +PROFILE=0 # Remember to update model and job name if running in batch mode!! if [[ $BATCH -eq 0 ]]; then @@ -41,7 +41,7 @@ if [[ $BATCH -eq 0 ]]; then MODEL_NAME="interactive_hybrid_cp" DEBUG=1 else - MODEL_NAME="hybrid_cp_0821_exp" + MODEL_NAME="interactive_hybrid_cp" fi WORKSPACE="/lustre/fsw/portfolios/coreai/users/pmannan/workspace" @@ -108,18 +108,17 @@ fi # CHECKPOINT_DIR="/lustre/fsw/portfolios/llmservice/users/trintamaki/workspace/output/video_sft_stage2_qwen_2p5_7b_radio_research_cp_0429_tp2/checkpoints" TP=1 -EXTRA_ARGS+=" --ckpt-format torch --use-distributed-optimizer " +EXTRA_ARGS+=" --ckpt-format torch_dist --use-distributed-optimizer " # EXTRA_ARGS+=" --overlap-param-gather --overlap-grad-reduce " export CUDA_DEVICE_MAX_CONNECTIONS=1 EXTRA_ARGS+=" --use-precision-aware-optimizer --main-grads-dtype bf16 --main-params-dtype fp16 --exp-avg-dtype fp16 --exp-avg-sq-dtype fp16 " OPTIONS=" \ - --use-checkpoint-args \ --disable-bias-linear \ --sft \ --tokenizer-type SFTTokenizer \ - --sft-tokenizer-prompt-format nemotron-h-aligned + --sft-tokenizer-prompt-format nemotron-h-aligned \ --tokenizer-model /lustre/fsw/portfolios/llmservice/users/kezhik/images/Nemotron-H-4B-Instruct \ --transformer-impl transformer_engine \ --normalization RMSNorm \ @@ -134,7 +133,7 @@ OPTIONS=" \ --tensor-model-parallel-size ${TP} \ --pipeline-model-parallel-size 1 \ --rerun-mode disabled \ - --num-layers 2 \ + --num-layers 4 \ --hidden-size 2048 \ --ffn-hidden-size 8192 \ --add-qkv-bias \ @@ -172,9 +171,11 @@ OPTIONS=" \ ${EXTRA_ARGS} \ --distributed-timeout-minutes 60 \ --calculate-per-token-loss \ + --attention-backend flash \ + --disable-gloo-process-groups \ + --use-dist-ckpt \ --hybrid-context-parallel \ --max-seqlen-per-cp-rank 16384 \ - --attention-backend flash \ " # --hybrid-context-parallel \ # --max-seqlen-per-cp-rank 16384 \ @@ -182,13 +183,13 @@ OPTIONS=" \ # Interactive or batch mode if [[ $BATCH -eq 0 ]]; then if [[ $PROFILE -eq 1 ]]; then - nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi -o gpt_sft_hetero_cp_iter7_8_flash torchrun --nproc_per_node ${NUM_GPU} pretrain_gpt.py ${OPTIONS} + nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi -o gpt_sft_hetero_cp_iter7_8_flash_global_64 torchrun --nproc_per_node ${NUM_GPU} pretrain_gpt.py ${OPTIONS} else torchrun --nproc_per_node ${NUM_GPU} pretrain_gpt.py ${OPTIONS} fi else if [[ $PROFILE -eq 1 ]]; then - run_cmd="cd ${SOURCE}; nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi --capture-range-end stop -o stage3_iter3_4_tp2cp4_w_tpcp_shard_%q{SLURM_PROCID} python -u pretrain_gpt.py ${OPTIONS}" + run_cmd="cd ${SOURCE}; nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi --capture-range-end stop -o without_hetero_cp_global_%q{SLURM_PROCID} python -u pretrain_gpt.py ${OPTIONS}" else run_cmd="cd ${SOURCE}; python -u pretrain_gpt.py ${OPTIONS}" fi diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index e01ac0d9f2f..631e74595d7 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -198,6 +198,8 @@ def unpack_batch(self, batch): sub_sample_dict = {} start_idx = sample["cu_seqlens"][sub_sample] end_idx = sample["cu_seqlens"][sub_sample + 1] + if end_idx - start_idx == 0: + continue for key in sample.keys(): if key in ["cu_seqlens", "batch_idx", "max_seqlen"]: continue @@ -214,6 +216,7 @@ def __next__(self) -> Any: for sample in batch: subsample_seqlens.extend([int(sample["cu_seqlens"][i+1] - sample["cu_seqlens"][i]) for i in range(0, sample["cu_seqlens"].shape[0] - 1)]) subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda() + subsample_seqlens = subsample_seqlens[subsample_seqlens != 0] seqlens_gathered, offsets = self.get_global_seqlens(subsample_seqlens) From bd04cfe6a02c4043c0c81ee17eb3d36623372e7c Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Mon, 15 Sep 2025 17:50:44 -0700 Subject: [PATCH 20/81] Fix fill_empty_gpus workload assignment --- .../pipeline_parallel/hybrid_cp_schedule.py | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index 631e74595d7..39947fd377e 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -548,6 +548,11 @@ def fill_empty_gpus(micro_batches, exec_times, sample_ids_per_gpu, group_members existing_group_sizes = set(group_size.values()) if not existing_group_sizes: return # No groups exist, cannot redistribute + + # TODO(pmannan): Should we assert that no GPU has more than 1 sub-sample assigned to it? + # Why would there be empty GPUs if some GPUs have multiple sub-samples assigned to it? + # But if it does happen if seq len forces it, I need to change the assumptions below. + # Below code assumes that there is 1 sub-sample in the GPUs that we will expand. min_group_size = min(existing_group_sizes) next_power = min_group_size * 2 @@ -560,34 +565,37 @@ def fill_empty_gpus(micro_batches, exec_times, sample_ids_per_gpu, group_members for gid, size in group_size.items(): if size == min_group_size: members = group_members[gid] - # get_new_work_queue(members[-1], min_group_size) - needed_count = min_group_size - current_gpu = members[-1] + needed_count = next_power - min_group_size + group_start_gpu = members[0] + group_end_gpu = members[-1] empty_gpu = [idx for idx, work in enumerate(micro_batches) if not work][0] assert not all(work for work in micro_batches[empty_gpu:empty_gpu+needed_count]), f"Not enough empty GPUs to expand or there are empty GPUs between work scheduled which is not allowed." - work_to_push = micro_batches[current_gpu + 1 : empty_gpu] # This is work of all other subsequent sub-samples - exec_times_to_push = exec_times[current_gpu + 1 : empty_gpu] - sample_ids_to_push = sample_ids_per_gpu[current_gpu + 1 : empty_gpu] + work_to_push = micro_batches[group_end_gpu + 1 : empty_gpu] # This is work of all other subsequent sub-samples + exec_times_to_push = exec_times[group_end_gpu + 1 : empty_gpu] + sample_ids_to_push = sample_ids_per_gpu[group_end_gpu + 1 : empty_gpu] new_micro_batches = [[]] * len(micro_batches) new_exec_times = [0.0] * len(exec_times) new_sample_ids_per_gpu = [[]] * len(sample_ids_per_gpu) - for i in range(current_gpu+1): + # No change in work until the group selected for expansion + for i in range(group_start_gpu): new_micro_batches[i] = micro_batches[i] new_exec_times[i] = exec_times[i] new_sample_ids_per_gpu[i] = sample_ids_per_gpu[i] - for i in range(needed_count): - new_micro_batches[current_gpu + 1 +i] = micro_batches[current_gpu] - new_exec_times[current_gpu + 1 + i] = exec_times[current_gpu] - new_sample_ids_per_gpu[current_gpu + 1 + i] = sample_ids_per_gpu[current_gpu] + # The work is distributed across the expanded group + for i in range(group_start_gpu, group_end_gpu + needed_count + 1): + new_micro_batches[i] = micro_batches[group_end_gpu] + new_exec_times[i] = self.get_total_workload(micro_batches[group_end_gpu][0], next_power) + new_sample_ids_per_gpu[i] = sample_ids_per_gpu[group_end_gpu] + # Any assigned work on expanded GPUs is pushed for i, work in enumerate(work_to_push): - new_micro_batches[current_gpu + needed_count + 1 + i] = work - new_exec_times[current_gpu + needed_count + 1 + i] = exec_times_to_push[i] - new_sample_ids_per_gpu[current_gpu + needed_count + 1 + i] = sample_ids_to_push[i] + new_micro_batches[group_end_gpu + needed_count + 1 + i] = work + new_exec_times[group_end_gpu + needed_count + 1 + i] = exec_times_to_push[i] + new_sample_ids_per_gpu[group_end_gpu + needed_count + 1 + i] = sample_ids_to_push[i] group_size[gid] = next_power group_members[gid] = list(range(members[0], members[-1] + needed_count + 1)) From 160f51401b0bed958d0d55302194017f85d1d864 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Tue, 16 Sep 2025 07:13:26 -0700 Subject: [PATCH 21/81] llama3 8B benchmarking script --- examples/run_llama38b_hybrid_cp.sh | 212 +++++++++++++++++++++++++++++ 1 file changed, 212 insertions(+) create mode 100755 examples/run_llama38b_hybrid_cp.sh diff --git a/examples/run_llama38b_hybrid_cp.sh b/examples/run_llama38b_hybrid_cp.sh new file mode 100755 index 00000000000..03ee9dc553b --- /dev/null +++ b/examples/run_llama38b_hybrid_cp.sh @@ -0,0 +1,212 @@ +#!/bin/bash + +#SBATCH -A coreai_dlalgo_genai +# DFW: batch +# OCI-NRT: batch_block1 +# OCI-IAD: batch_block1,batch_block3,batch_block4,backfill_block1,backfill_block2,backfill_block3,backfill_block4 +#SBATCH -p batch +#SBATCH -t 00:20:00 +#SBATCH --mem=0 +#SBATCH --ntasks-per-node=8 +#SBATCH --nodes=4 +#SBATCH --exclusive +#SBATCH --gpus-per-node=8 +#SBATCH --job-name=hetero_cp_global + +export NCCL_IB_SL=1 +export TOKENIZERS_PARALLELISM="false" +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +#export NVTE_DEBUG=1 +#export NVTE_DEBUG_LEVEL=2 + +USER=$SLURM_JOB_USER + +# Auto-detect batch or interactive mode. +which srun +BATCH=$((1-$?)) + +DEBUG=0 +USE_TILING=1 +USE_CP=1 +USE_TE_CE=0 +USE_FLASH_ATTN=0 +USE_FSDP=0 +USE_CUSTOM_FSDP=0 +PROFILE=0 + +# Remember to update model and job name if running in batch mode!! +if [[ $BATCH -eq 0 ]]; then + DATETIME=`date +'%y-%m-%d-%H-%M-%S'` + MODEL_NAME="interactive_hybrid_cp" + DEBUG=1 +else + MODEL_NAME="interactive_hybrid_cp" +fi + +WORKSPACE="/lustre/fsw/portfolios/coreai/users/pmannan/workspace" +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR=${OUTPUT}/checkpoints +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" +DATACACHE_DIR="${OUTPUT}/data_cache" + +export HF_DATASETS_CACHE="${OUTPUT}/hf_datasets_cache" + +DATA_TRAIN="/lustre/fs1/portfolios/llmservice/users/adithyare/sft/nano_v2_fake_packed_131072_10000_rndm//stage1_stage2_multiling_128k_seq_packed.empty_assist_filtered.shuf.jsonl" + +SEQ_LEN=131072 #131072 #81920 #65536 + +if [[ $DEBUG -eq 1 ]]; then + MBZ=1 + BZ=64 + NW=4 + AD=0.0 + HD=0.0 + LI=1 + + # EXTRA_ARGS="--deterministic-mode --use-cpu-initialization" + + NONDETERMINISTIC_ATTN=1 + + NUM_GPU=8 + export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + + #export NCCL_ALGO=Tree + #export CUBLAS_WORKSPACE_CONFIG=:4096:8 +else + MBZ=1 + BZ=64 + NW=8 + AD=0.0 + HD=0.0 + LI=1 + EXTRA_ARGS="" + NONDETERMINISTIC_ATTN=1 + NUM_GPU=8 +fi + +if [[ $USE_CP -eq 1 ]]; then + if [[ $BATCH -eq 1 ]]; then + CP_SIZE=8 + else + CP_SIZE=8 + fi + EXTRA_ARGS+=" --context-parallel-size ${CP_SIZE} " +fi + +if [[ $USE_TE_CE -eq 1 ]]; then + EXTRA_ARGS+=" --enable-te-ce --cross-entropy-loss-fusion " +fi + +if [[ $PROFILE -eq 1 ]]; then + EXTRA_ARGS+="--profile --profile-step-start 7 --profile-step-end 8 --profile-ranks 0 16 " +fi + +# CHECKPOINT_DIR="/lustre/fsw/portfolios/llmservice/users/trintamaki/workspace/output/video_sft_stage2_qwen_2p5_7b_radio_research_cp_0429_tp2/checkpoints" +TP=1 +EXTRA_ARGS+=" --ckpt-format torch_dist --use-distributed-optimizer " +# EXTRA_ARGS+=" --overlap-param-gather --overlap-grad-reduce " +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +EXTRA_ARGS+=" --use-precision-aware-optimizer --main-grads-dtype bf16 --main-params-dtype fp16 --exp-avg-dtype fp16 --exp-avg-sq-dtype fp16 " + +OPTIONS=" \ + --disable-bias-linear \ + --sft \ + --tokenizer-type SFTTokenizer \ + --sft-tokenizer-prompt-format nemotron-h-aligned \ + --tokenizer-model /lustre/fsw/portfolios/llmservice/users/kezhik/images/Nemotron-H-4B-Instruct \ + --transformer-impl transformer_engine \ + --normalization RMSNorm \ + --norm-epsilon 1e-06 \ + --attention-dropout ${AD} \ + --hidden-dropout ${HD} \ + --untie-embeddings-and-output-weights \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size 1 \ + --rerun-mode disabled \ + --num-layers 28 \ + --hidden-size 4096 \ + --ffn-hidden-size 14336 \ + --add-qkv-bias \ + --num-attention-heads 32 \ + --num-workers ${NW} \ + --exit-duration-in-mins 230 \ + --seq-length ${SEQ_LEN} \ + --max-position-embeddings ${SEQ_LEN} \ + --train-samples 100000 \ + --lr-warmup-samples 20000 \ + --micro-batch-size ${MBZ} \ + --global-batch-size ${BZ} \ + --lr 2e-5 \ + --min-lr 0.0 \ + --lr-decay-style cosine \ + --log-interval ${LI} \ + --eval-iters 10 \ + --eval-interval 999999 \ + --data-path ${DATA_TRAIN} \ + --save-interval 1000 \ + --save ${FINETUNE_DIR} \ + --load ${FINETUNE_DIR} \ + --data-cache-path ${DATACACHE_DIR} \ + --use-mcore-models \ + --no-create-attention-mask-in-dataloader \ + --no-mmap-bin-files \ + --split 100,0,0 \ + --clip-grad 1.0 \ + --weight-decay 0.05 \ + --adam-beta1 0.9 \ + --adam-beta2 0.999 \ + --init-method-std 0.014 \ + --bf16 \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + ${EXTRA_ARGS} \ + --distributed-timeout-minutes 60 \ + --calculate-per-token-loss \ + --attention-backend flash \ + --group-query-attention \ + --num-query-groups 8 \ + --disable-gloo-process-groups \ + --use-dist-ckpt \ + --hybrid-context-parallel \ + --max-seqlen-per-cp-rank 8192 \ +" + +# --recompute-method block \ +# --recompute-num-layers 14 \ +# --recompute-granularity full \ + +# Interactive or batch mode +if [[ $BATCH -eq 0 ]]; then + if [[ $PROFILE -eq 1 ]]; then + nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi -o gpt_sft_hetero_cp_iter7_8_flash_global_64 torchrun --nproc_per_node ${NUM_GPU} pretrain_gpt.py ${OPTIONS} + else + torchrun --nproc_per_node ${NUM_GPU} pretrain_gpt.py ${OPTIONS} + fi +else + if [[ $PROFILE -eq 1 ]]; then + run_cmd="cd ${SOURCE}; nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi --capture-range-end stop -o llama38b_hybrid_cp_%q{SLURM_PROCID} python -u pretrain_gpt.py ${OPTIONS}" + else + run_cmd="cd ${SOURCE}; python -u pretrain_gpt.py ${OPTIONS}" + fi + + DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` + + echo "run_cmd: ${run_cmd}" + + srun -l --verbose \ + --container-image /lustre/fsw/portfolios/llmservice/users/matthieul/docker/megatron-dev-img-05142025-pytorch-dev-te-cd37379-editable-energon-mamba-fix-vlmeval-pad-conv.sqsh \ + --container-mounts "/lustre" \ + --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ + sh -c "${run_cmd}" + + set +x +fi \ No newline at end of file From b0eb325f5fa1cd2ce66772ce3ef38772cf5a380a Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Tue, 16 Sep 2025 07:32:07 -0700 Subject: [PATCH 22/81] Bug fix --- megatron/core/pipeline_parallel/hybrid_cp_schedule.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index 39947fd377e..047b5189a74 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -549,13 +549,9 @@ def fill_empty_gpus(micro_batches, exec_times, sample_ids_per_gpu, group_members if not existing_group_sizes: return # No groups exist, cannot redistribute - # TODO(pmannan): Should we assert that no GPU has more than 1 sub-sample assigned to it? - # Why would there be empty GPUs if some GPUs have multiple sub-samples assigned to it? - # But if it does happen if seq len forces it, I need to change the assumptions below. - # Below code assumes that there is 1 sub-sample in the GPUs that we will expand. - min_group_size = min(existing_group_sizes) - next_power = min_group_size * 2 + # We have Hybrid DPxCP groups for every power of 2 of GPUs or the entire DPxCP group. + next_power = min(min_group_size * 2, total_gpus) # Find the first group of min_group_size that can be expanded expandable_gid = None From 7e9fe9045deaa930d75dda29158282eec0668ce5 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Tue, 16 Sep 2025 08:28:04 -0700 Subject: [PATCH 23/81] Revert "ADLR/megatron-lm!3963 - SFT chat template and pad token changes for Nano v2 12B" This reverts commit 746c91341efbb8b31dfbe2bac1801ea6435af28b. --- megatron/training/tokenizer/sft_tokenizer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/megatron/training/tokenizer/sft_tokenizer.py b/megatron/training/tokenizer/sft_tokenizer.py index c4a7a5ca318..99024ab6a45 100644 --- a/megatron/training/tokenizer/sft_tokenizer.py +++ b/megatron/training/tokenizer/sft_tokenizer.py @@ -4,7 +4,7 @@ from typing import Dict, List, Union import numpy as np -nemotron_nano_v2_custom_template = """{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'system' %}{{ 'System\n' + content.replace('/think', '').replace('/no_think', '').strip() + '\n' }}{% elif message['role'] == 'user' %}{{ 'User\n' + content.replace('/think', '').replace('/no_think', '').strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant\n' + content.strip() + '\n\n' }}{% endif %}{% endfor %}""" +nemotron_h_aligned_custom_template = """{% for message in messages %}{% if message['role'] == 'system' %}{{ 'System\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'user' %}{{ 'User\n' + message['content'].strip() + '\n' + 'Assistant\n' }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() + '\n' }}{% endif %}{% endfor %}""" from megatron.core.datasets.megatron_tokenizer import MegatronLegacyTokenizer from megatron.training.datasets.sft_dataset import IGNORE_INDEX @@ -41,11 +41,11 @@ def __init__( self._vocab_size = len(tokenizer) self._tokenizer = tokenizer - if prompt_format == "nemotron-nano-v2": + if prompt_format == "nemotron-h-aligned": self._prompt_config = PromptConfig( - assistant_prefix_len=3, - pad_token_id=tokenizer.convert_tokens_to_ids(""), - custom_chat_template=nemotron_nano_v2_custom_template, + assistant_prefix_len=0, + pad_token_id=tokenizer.convert_tokens_to_ids(""), + custom_chat_template=nemotron_h_aligned_custom_template, has_bos=False, has_system_role=True, ) From b635d159d7f066e7869742e031c7b0fbc17f1778 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Tue, 16 Sep 2025 09:08:48 -0700 Subject: [PATCH 24/81] Turn on legacy tokenizer --- examples/run_llama38b_hybrid_cp.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/run_llama38b_hybrid_cp.sh b/examples/run_llama38b_hybrid_cp.sh index 03ee9dc553b..a8d618da5aa 100755 --- a/examples/run_llama38b_hybrid_cp.sh +++ b/examples/run_llama38b_hybrid_cp.sh @@ -117,6 +117,7 @@ EXTRA_ARGS+=" --use-precision-aware-optimizer --main-grads-dtype bf16 --main-par OPTIONS=" \ --disable-bias-linear \ --sft \ + --legacy-tokenizer \ --tokenizer-type SFTTokenizer \ --sft-tokenizer-prompt-format nemotron-h-aligned \ --tokenizer-model /lustre/fsw/portfolios/llmservice/users/kezhik/images/Nemotron-H-4B-Instruct \ From dc8139f2ddd22bb1a70ded90f33d2280b6f8c48e Mon Sep 17 00:00:00 2001 From: Mcore Bot Date: Tue, 16 Sep 2025 16:11:40 +0000 Subject: [PATCH 25/81] chore: Format files --- .../common/embeddings/rotary_pos_embedding.py | 6 +- megatron/core/models/gpt/gpt_model.py | 3 +- megatron/core/packed_seq_params.py | 4 +- megatron/core/parallel_state.py | 26 +- .../pipeline_parallel/hybrid_cp_schedule.py | 418 ++++++++++++------ megatron/core/utils.py | 55 ++- 6 files changed, 345 insertions(+), 167 deletions(-) diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index 72b55caf9ff..2306f43631c 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -175,7 +175,9 @@ def get_emb(self, max_seq_len: int, offset: int = 0) -> Tensor: emb = emb[:, None, None, :] return emb - def forward(self, max_seq_len: int, offset: int = 0, packed_seq_params: Optional[PackedSeqParams] = None) -> Tensor: + def forward( + self, max_seq_len: int, offset: int = 0, packed_seq_params: Optional[PackedSeqParams] = None + ) -> Tensor: """Forward pass of RoPE embedding. Args: @@ -199,7 +201,7 @@ def forward(self, max_seq_len: int, offset: int = 0, packed_seq_params: Optional # slice rotary_pos_emb along sequence dimension and select the parition of the current # CP rank emb = get_pos_emb_on_this_cp_rank(emb, 0, cp_group) - + return emb def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index d5a7914af5f..eb01026dc88 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -301,8 +301,7 @@ def _preprocess( inference_context, self.decoder, decoder_input, self.config, packed_seq_params ) rotary_pos_emb = self.rotary_pos_emb( - rotary_seq_len, - packed_seq_params=packed_seq_params, + rotary_seq_len, packed_seq_params=packed_seq_params ) elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention: if self.training or not self.config.flash_decode: diff --git a/megatron/core/packed_seq_params.py b/megatron/core/packed_seq_params.py index ea5e9463187..fe09764d241 100644 --- a/megatron/core/packed_seq_params.py +++ b/megatron/core/packed_seq_params.py @@ -1,9 +1,9 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from dataclasses import dataclass - -from torch import Tensor from typing import List + import torch.distributed as dist +from torch import Tensor @dataclass diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index f37fc7b59f8..87e92b8d3e7 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -366,6 +366,7 @@ def create_hierarchical_groups( assert rank not in ranks or len(hierarchical_groups_gloo) == len(hierarchical_group_sizes) return hierarchical_groups, hierarchical_groups_gloo + def create_hybrid_dp_cp_groups(rank, ranks, pg_options): hybrid_dp_cp_groups = {} # Generate group for every power of 2 up to the number of CP ranks @@ -374,15 +375,18 @@ def create_hybrid_dp_cp_groups(rank, ranks, pg_options): for group_size in group_sizes: for i in range(0, len(ranks), group_size): group = create_group( - ranks[i:i+group_size], + ranks[i : i + group_size], pg_options=pg_options, group_desc=f"HYBRID_DP_CP_GROUP_{group_size}", ) - if rank in ranks[i:i+group_size]: - assert group_size not in hybrid_dp_cp_groups, f"Rank {rank} appears in multiple Hybrid DP CP groups of size {group_size}" + if rank in ranks[i : i + group_size]: + assert ( + group_size not in hybrid_dp_cp_groups + ), f"Rank {rank} appears in multiple Hybrid DP CP groups of size {group_size}" hybrid_dp_cp_groups[group_size] = group return hybrid_dp_cp_groups + class RankGenerator(object): """A class for generating rank groups for different modes of parallelism.""" @@ -856,12 +860,18 @@ def initialize_model_parallel( del os.environ["NCCL_COLLNET_ENABLE"] if hybrid_context_parallel: - assert len(ranks_with_cp) % 2 == 0, "Hybrid context parallel requires an even number of ranks" + assert ( + len(ranks_with_cp) % 2 == 0 + ), "Hybrid context parallel requires an even number of ranks" global _HYBRID_DP_CP_GROUPS if rank in ranks_with_cp: - _HYBRID_DP_CP_GROUPS.update(create_hybrid_dp_cp_groups(rank, ranks_with_cp, get_nccl_options("dp_cp", nccl_comm_cfgs))) - #TODO: Are gloo groups needed for hybrid cp? - + _HYBRID_DP_CP_GROUPS.update( + create_hybrid_dp_cp_groups( + rank, ranks_with_cp, get_nccl_options("dp_cp", nccl_comm_cfgs) + ) + ) + # TODO: Are gloo groups needed for hybrid cp? + for ranks in decoder_rank_generator.get_ranks('dp'): group = create_group( ranks, @@ -1375,6 +1385,7 @@ def get_hierarchical_context_parallel_groups(check_initialized=True): assert _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS is not None return _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS + def get_hybrid_data_context_parallel_groups(check_initialized=True, group_size=None): """Get the hybrid context parallel groups the caller rank belongs to.""" # If the group size is the same as the entire DPxCP group, return the original group @@ -1386,6 +1397,7 @@ def get_hybrid_data_context_parallel_groups(check_initialized=True, group_size=N assert _HYBRID_DP_CP_GROUPS is not None return _HYBRID_DP_CP_GROUPS[group_size] + def get_embedding_group(check_initialized=True): """Get the embedding group the caller rank belongs to.""" if check_initialized: diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index 047b5189a74..1e4e809bf19 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -1,30 +1,30 @@ +import heapq +from collections import deque +from functools import lru_cache +from math import ceil, log2 +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + import torch + from megatron.core import parallel_state from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.rerun_state_machine import RerunDataIterator -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union -from functools import lru_cache -from collections import deque -from math import ceil, log2 -import heapq -class HybridCPDatasetWrapper(): + +class HybridCPDatasetWrapper: """ A wrapper class that wraps around any existing dataset. It adds batch_idx to the sample and returns the sample. - + This is required if certains DP ranks need the dataset idx to access the appropriate sub-sample assigned to them from the file system even if it was not originally assigned to them. - + Args: dataset: The original dataset to wrap around """ - - def __init__( - self, - dataset, - ): + + def __init__(self, dataset): self.dataset = dataset def __len__(self): @@ -36,61 +36,82 @@ def __getitem__(self, idx) -> Any: """ sample = self.dataset[idx] sample["batch_idx"] = idx - assert "cu_seqlens" in sample, "cu_seqlens must be in the sample to use hybrid context parallel" + assert ( + "cu_seqlens" in sample + ), "cu_seqlens must be in the sample to use hybrid context parallel" return sample -class HybridCPDataLoaderWrapper(): + +class HybridCPDataLoaderWrapper: """ A wrapper class that wraps around an existing data_iterator. - For every __next__ call, + For every __next__ call, 1. Each DP rank pulls a batch of packed samples. 2. Extracts the sequence lengths of each sub-sample and all-gathers across the DP group. 3. Schedules the sub-samples to the DPxCP ranks using the BalancedCPScheduler. 4. Based on the schedule, reroutes the sub-samples to the correct rank using all-to-all. 5. Returns the assigned sub-samples to this rank. - + Args: data_iterator: The original data_iterator to wrap around config: The config object containing the max_seqlen_per_cp_rank """ - - def __init__( - self, - data_iterator, - config, - ): + + def __init__(self, data_iterator, config): self.data_iterator = data_iterator self.sample_count = 0 self.config = config - self.cp_balancing_scheduler = BalancedCPScheduler(max_seq_len_per_rank=self.config.max_seqlen_per_cp_rank) - self.total_hdp_gpus = parallel_state.get_data_parallel_world_size(with_context_parallel=True) + self.cp_balancing_scheduler = BalancedCPScheduler( + max_seq_len_per_rank=self.config.max_seqlen_per_cp_rank + ) + self.total_hdp_gpus = parallel_state.get_data_parallel_world_size( + with_context_parallel=True + ) def __iter__(self): """Return self as an iterator.""" return self - + def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: # Collect the number of subsamples from all ranks local_len = torch.tensor([subsample_seqlens.shape[0]], dtype=torch.int32).cuda() - dp_subsample_count = [torch.zeros_like(local_len) for _ in range(parallel_state.get_data_parallel_world_size())] - torch.distributed.all_gather(dp_subsample_count, local_len, group=parallel_state.get_data_parallel_group()) - + dp_subsample_count = [ + torch.zeros_like(local_len) + for _ in range(parallel_state.get_data_parallel_world_size()) + ] + torch.distributed.all_gather( + dp_subsample_count, local_len, group=parallel_state.get_data_parallel_group() + ) + # Find the maximum number of subsamples across all ranks and pad the subsample_seqlens to the max length dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1) max_sub_samples = int(dp_subsample_counts.max().item()) if local_len.item() < max_sub_samples: - subsample_seqlens_padded = torch.cat([subsample_seqlens, torch.zeros(max_sub_samples - local_len.item(), dtype=torch.int32).cuda()], dim=0) + subsample_seqlens_padded = torch.cat( + [ + subsample_seqlens, + torch.zeros(max_sub_samples - local_len.item(), dtype=torch.int32).cuda(), + ], + dim=0, + ) else: subsample_seqlens_padded = subsample_seqlens # Gather the subsample_seqlens from all ranks - seqlens_gathered = [torch.empty_like(subsample_seqlens_padded) for _ in range(parallel_state.get_data_parallel_world_size())] - torch.distributed.all_gather(seqlens_gathered, subsample_seqlens_padded, group=parallel_state.get_data_parallel_group()) - + seqlens_gathered = [ + torch.empty_like(subsample_seqlens_padded) + for _ in range(parallel_state.get_data_parallel_world_size()) + ] + torch.distributed.all_gather( + seqlens_gathered, + subsample_seqlens_padded, + group=parallel_state.get_data_parallel_group(), + ) + # Trim each seqlens_gathered to the length of the correct sample for dp_rank, seqlen in enumerate(seqlens_gathered): - seqlens_gathered[dp_rank] = seqlen[:dp_subsample_counts[dp_rank]] + seqlens_gathered[dp_rank] = seqlen[: dp_subsample_counts[dp_rank]] seqlens_gathered = torch.cat(seqlens_gathered, dim=0) seqlens_gathered = seqlens_gathered.cpu().tolist() @@ -108,19 +129,27 @@ def get_global_id_seqlens(self, num_local_subsamples, offsets, seqlens_gathered) # Create a list of (global_id, seqlen) tuples for scheduling global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] # Get the global IDs locally present on this rank - global_ids_this_rank = global_ids[offsets[dp_rank]:offsets[dp_rank] + num_local_subsamples] + global_ids_this_rank = global_ids[ + offsets[dp_rank] : offsets[dp_rank] + num_local_subsamples + ] return global_id_seqlens, global_ids_this_rank def _gid_to_src_rank(self, gid: int, offsets: List[int]) -> int: - dp_src_rank = torch.bucketize(gid, offsets[1:]-1) - hdp_rank = torch.distributed.get_process_group_ranks(parallel_state.get_data_parallel_group())[dp_src_rank] + dp_src_rank = torch.bucketize(gid, offsets[1:] - 1) + hdp_rank = torch.distributed.get_process_group_ranks( + parallel_state.get_data_parallel_group() + )[dp_src_rank] return hdp_rank - - def reroute_samples_to_hdp_ranks(self, batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets): + + def reroute_samples_to_hdp_ranks( + self, batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets + ): gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} hdp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) - dp_ranks = torch.distributed.get_process_group_ranks(parallel_state.get_data_parallel_group()) + dp_ranks = torch.distributed.get_process_group_ranks( + parallel_state.get_data_parallel_group() + ) data_keys = batch[0].keys() # Create the send plan @@ -128,19 +157,30 @@ def reroute_samples_to_hdp_ranks(self, batch, global_ids_this_rank, global_id_se for d in range(self.total_hdp_gpus): for sample_id_group in sample_id_groups: - combined_sample_id_groups[d].extend(sample_id_group[d]) + combined_sample_id_groups[d].extend(sample_id_group[d]) for dest_rank in range(self.total_hdp_gpus): combined_sample_id_groups[dest_rank].sort() # Filter out samples that are not present on this rank - send_ids_sorted = [gid for d in dp_ranks for gid in combined_sample_id_groups[d] if gid in global_ids_this_rank] + send_ids_sorted = [ + gid + for d in dp_ranks + for gid in combined_sample_id_groups[d] + if gid in global_ids_this_rank + ] # send_counts = [len(combined_sample_id_groups[d]) for d in range(self.total_hdp_gpus)] send_lens_split = [0] * self.total_hdp_gpus for dest_rank in range(self.total_hdp_gpus): if dest_rank in dp_ranks: - send_lens_split[dest_rank] = sum([global_id_seqlens[gid][1] for gid in combined_sample_id_groups[dest_rank] if gid in global_ids_this_rank]) + send_lens_split[dest_rank] = sum( + [ + global_id_seqlens[gid][1] + for gid in combined_sample_id_groups[dest_rank] + if gid in global_ids_this_rank + ] + ) else: # We only need to share local data with DP ranks that have different data. send_lens_split[dest_rank] = 0 @@ -150,14 +190,18 @@ def reroute_samples_to_hdp_ranks(self, batch, global_ids_this_rank, global_id_se for gid in combined_sample_id_groups[hdp_rank]: src_rank = self._gid_to_src_rank(gid, offsets) recv_sample_id_groups[src_rank].append(gid) - + recv_lens_split = [0] * self.total_hdp_gpus for src_rank in range(self.total_hdp_gpus): - recv_lens_split[src_rank] = sum([global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]]) + recv_lens_split[src_rank] = sum( + [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]] + ) # print(f"rank: {torch.distributed.get_rank()}, recv_lens_split: {recv_lens_split} send_lens_split: {send_lens_split}") - recv_ids_sorted = [gid for d in range(self.total_hdp_gpus) for gid in recv_sample_id_groups[d]] + recv_ids_sorted = [ + gid for d in range(self.total_hdp_gpus) for gid in recv_sample_id_groups[d] + ] recv_counts = [len(recv_sample_id_groups[d]) for d in range(self.total_hdp_gpus)] recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))] @@ -167,18 +211,24 @@ def _pack_sample_by_key(key: str) -> torch.Tensor: for gid in send_ids_sorted: t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True) flattened_tensors.append(t) - return torch.cat(flattened_tensors, dim=0) if flattened_tensors else torch.empty(0, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) + return ( + torch.cat(flattened_tensors, dim=0) + if flattened_tensors + else torch.empty(0, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) + ) def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): cursor = 0 for i, gid in enumerate(recv_ids_sorted): sample_len = global_id_seqlens[gid][1] - recv_samples[i][key] = recv_tensor[cursor:cursor+sample_len] + recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len] cursor += sample_len - + for key in data_keys: send_tensor = _pack_sample_by_key(key) - recv_tensor = torch.empty(sum(recv_lens_split), device=torch.cuda.current_device(), dtype=send_tensor.dtype) + recv_tensor = torch.empty( + sum(recv_lens_split), device=torch.cuda.current_device(), dtype=send_tensor.dtype + ) torch.distributed.all_to_all_single( output=recv_tensor, input=send_tensor, @@ -187,8 +237,10 @@ def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): group=parallel_state.get_data_parallel_group(with_context_parallel=True), ) _unpack_sample_by_key(key, recv_tensor) - - recv_sample_with_id = {recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted)} + + recv_sample_with_id = { + recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted) + } return recv_sample_with_id def unpack_batch(self, batch): @@ -206,7 +258,7 @@ def unpack_batch(self, batch): sub_sample_dict[key] = sample[key][start_idx:end_idx] batch_unpacked.append(sub_sample_dict) return batch_unpacked - + def __next__(self) -> Any: """ Get the next item from the dataset, pull scheduling metadata and return it. @@ -214,32 +266,46 @@ def __next__(self) -> Any: batch = next(self.data_iterator) subsample_seqlens = [] for sample in batch: - subsample_seqlens.extend([int(sample["cu_seqlens"][i+1] - sample["cu_seqlens"][i]) for i in range(0, sample["cu_seqlens"].shape[0] - 1)]) + subsample_seqlens.extend( + [ + int(sample["cu_seqlens"][i + 1] - sample["cu_seqlens"][i]) + for i in range(0, sample["cu_seqlens"].shape[0] - 1) + ] + ) subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda() subsample_seqlens = subsample_seqlens[subsample_seqlens != 0] seqlens_gathered, offsets = self.get_global_seqlens(subsample_seqlens) - global_id_seqlens, global_ids_this_rank = self.get_global_id_seqlens(subsample_seqlens.shape[0], offsets, seqlens_gathered) + global_id_seqlens, global_ids_this_rank = self.get_global_id_seqlens( + subsample_seqlens.shape[0], offsets, seqlens_gathered + ) # global_id_seqlens = sorted(global_id_seqlens, key=lambda x: x[1], reverse=True) - groups, sample_id_groups = self.cp_balancing_scheduler.get_groups_and_subsamples(global_id_seqlens, self.config) + groups, sample_id_groups = self.cp_balancing_scheduler.get_groups_and_subsamples( + global_id_seqlens, self.config + ) # sample["groups"] = groups # sample["sample_id_groups"] = sample_id_groups batch = self.unpack_batch(batch) - samples_this_rank_with_id = self.reroute_samples_to_hdp_ranks(batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets) + samples_this_rank_with_id = self.reroute_samples_to_hdp_ranks( + batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets + ) # for sample_id, sample in samples_this_rank_with_id.items(): # sample["sample_id_groups"] return samples_this_rank_with_id, sample_id_groups + class BalancedCPScheduler: def __init__(self, max_seq_len_per_rank: int): self.max_seq_len_per_rank = max_seq_len_per_rank self.num_subsamples = 0 self.num_subsamples_processed = 0 self.free_resources = [] - self.total_hdp_gpus = parallel_state.get_data_parallel_world_size(with_context_parallel=True) + self.total_hdp_gpus = parallel_state.get_data_parallel_world_size( + with_context_parallel=True + ) @lru_cache(maxsize=128) def get_total_workload(self, seq_length: int, cp_size: Optional[int] = None): @@ -273,10 +339,10 @@ def make_buckets_equal( """ # Extract just the sequence lengths for determining k seqlens = [seq_len for _, seq_len in sample_seqlens] - + # Determine k based on unique GPU categories needed k = len({self.gpus_needed(L) for L in seqlens}) - + # Use the existing contiguous_equal_buckets function but with sample_seqlens # We need to modify it to work with tuples work = [] @@ -292,22 +358,23 @@ def make_buckets_equal( for i, (sample_id, seq_len) in enumerate(sample_seqlens): work = compute_estimator(seq_len) projected = cur_work + work - + # Check if we should close this bucket - if (cur and - (projected > target * 1.1 or # Too much work - len(sample_seqlens) - i <= remaining_k - len(buckets))): # Need to save sequences for remaining buckets + if cur and ( + projected > target * 1.1 # Too much work + or len(sample_seqlens) - i <= remaining_k - len(buckets) + ): # Need to save sequences for remaining buckets buckets.append(deque(cur)) cur, cur_work = [], 0.0 remaining_work -= sum(compute_estimator(seq_len) for _, seq_len in cur) remaining_k -= 1 - + cur.append((sample_id, seq_len)) cur_work += work - + if cur: buckets.append(deque(cur)) - + return buckets def next_hdp_group( @@ -315,9 +382,9 @@ def next_hdp_group( sample_seqlens: List[Tuple[int, int]], # List of (sample_id, sequence_length) tuples compute_estimator: Callable[[int], float], total_gpus: int, - delta: float = 0.05, # balance slack (e.g. 5 %) - strategy: str = "dp", # "dp" or "pp" - eps_bucket: float = 0.10, # ε target for bucket balance + delta: float = 0.05, # balance slack (e.g. 5 %) + strategy: str = "dp", # "dp" or "pp" + eps_bucket: float = 0.10, # ε target for bucket balance ) -> Tuple[List[List[int]], List[Tuple[int, int]], List[float], List[List[int]]]: """ Given a list of (sample_id, sequence_length) tuples, this function aims to assign @@ -325,14 +392,14 @@ def next_hdp_group( Once each microbatch is roughly balanced, we exit and return the microbatch and the leftover sequences. The function performs the following passes in order to form a balanced microbatch: - 1. We create buckets of sequences that are roughly balanced. + 1. We create buckets of sequences that are roughly balanced. We try to create as many buckets as possible CP sizes. 2. Given a bucket has sequences available, we assign the microbatch a. To a new set of GPUs if there are enough free GPUs. b. To an existing set of GPUs with the lowest load. 3. We check if the microbatch is balanced whenever we need to move onto a new CP size in the same set of GPUs. 4. We trim the microbatch if removing the last added sequence helps improve balance. - 5. If we run out of sequences to assign and there are empty GPUs, + 5. If we run out of sequences to assign and there are empty GPUs, we redistribute work to empty GPUs by recursively increasing the CP size of a sample until no empty GPUs are left.. #TODO: Add clarification on when we check for balance. What does prev_needed do? @@ -340,32 +407,38 @@ def next_hdp_group( Returns (*micro_batches*, *leftover_sample_seqlens*, *exec_times*, *sample_ids_per_gpu*). """ if not sample_seqlens: - return [[] for _ in range(total_gpus)], [], [0.0 for _ in range(total_gpus)], [[] for _ in range(total_gpus)] + return ( + [[] for _ in range(total_gpus)], + [], + [0.0 for _ in range(total_gpus)], + [[] for _ in range(total_gpus)], + ) # Use the improved bucketing that works with (sample_id, seq_len) tuples buckets = self.make_buckets_equal(sample_seqlens, compute_estimator) # Initialize tracking structures - micro_batches = [[] for _ in range(total_gpus)] - exec_times = [0.0 for _ in range(total_gpus)] + micro_batches = [[] for _ in range(total_gpus)] + exec_times = [0.0 for _ in range(total_gpus)] sample_ids_per_gpu = [[] for _ in range(total_gpus)] - gpu_group_id = [None] * total_gpus - group_members = {} - group_size = {} - next_gid = 0 + gpu_group_id = [None] * total_gpus + group_members = {} + group_size = {} + next_gid = 0 - pp_cursor = 0 - prev_needed = None - check_balance = False + pp_cursor = 0 + prev_needed = None + check_balance = False while buckets: # ---- Step 1 – pick the next sequence we COULD place ------------------ sample_seq_tuple = bucket_idx = None - needed = None + needed = None scan_order = ( - range(len(buckets)) if strategy == "dp" + range(len(buckets)) + if strategy == "dp" else [(pp_cursor + i) % len(buckets) for i in range(len(buckets))] ) @@ -398,14 +471,14 @@ def next_hdp_group( prev_needed = needed # (a) Existing groups of exactly this size - candidate_gids = [ - gid for gid, sz in group_size.items() if sz == needed - ] + candidate_gids = [gid for gid, sz in group_size.items() if sz == needed] if candidate_gids: best_gid, best_load = min( - ((gid, max(exec_times[r] for r in group_members[gid])) - for gid in candidate_gids), - key=lambda t: t[1] + ( + (gid, max(exec_times[r] for r in group_members[gid])) + for gid in candidate_gids + ), + key=lambda t: t[1], ) else: best_gid, best_load = None, float("inf") @@ -415,7 +488,7 @@ def next_hdp_group( if len(free_ranks) >= needed: free_sorted = sorted(free_ranks, key=lambda r: exec_times[r]) new_members = free_sorted[:needed] - new_load = exec_times[new_members[-1]] + new_load = exec_times[new_members[-1]] if new_load < best_load: best_gid = None @@ -432,13 +505,13 @@ def next_hdp_group( best_gid = next_gid next_gid += 1 group_members[best_gid] = chosen_members - group_size[best_gid] = needed + group_size[best_gid] = needed for r in chosen_members: gpu_group_id[r] = best_gid # ---- Step 3 – assign the sequence to every member of that group ------ per_gpu_cost = compute_estimator(seq_len) - + for r in chosen_members: micro_batches[r].append(seq_len) exec_times[r] += per_gpu_cost @@ -470,15 +543,19 @@ def next_hdp_group( # We keep assigning group of 2 as we do in descending order but GPU 7/15 never sees a microbatch assigned to it # until we run out of samples with CP2. # This means we are never balanced as min(exec_times) will always be 0. - # We need a smart way of identifying that we have run out of big samples and if we are having to + # We need a smart way of identifying that we have run out of big samples and if we are having to # assign work to a GPU already working, is it because there are empty GPUs? # Would assigning work to empty GPUs first by moving onto next CP bucket help? # But we need to remember to come back to this CP size bucket and then check for balance. - # Maybe the scheduling algorithm should look at empty GPUs and find work rather than going + # Maybe the scheduling algorithm should look at empty GPUs and find work rather than going # sequence by sequence. check_balance = True - if check_balance and buckets and max(exec_times) - min(exec_times) <= delta * max(exec_times): + if ( + check_balance + and buckets + and max(exec_times) - min(exec_times) <= delta * max(exec_times) + ): break # Gather leftovers (flatten remaining buckets, preserve order) @@ -486,7 +563,7 @@ def next_hdp_group( for b in buckets: for sample_seq_tuple in b: leftovers.append(sample_seq_tuple) - + # --------------------------------------------------------------------------- def trim_overload(): """ @@ -494,21 +571,21 @@ def trim_overload(): whenever doing so reduces the global slack. """ while True: - cur_max = max(exec_times) - cur_min = min(exec_times) + cur_max = max(exec_times) + cur_min = min(exec_times) cur_slack = cur_max - cur_min if cur_slack <= delta * cur_max: break - max_r = exec_times.index(cur_max) - gid = gpu_group_id[max_r] + max_r = exec_times.index(cur_max) + gid = gpu_group_id[max_r] members = group_members[gid] if not micro_batches[max_r] or len(micro_batches[max_r]) <= 1: break - seq = micro_batches[max_r][-1] - need = group_size[gid] + seq = micro_batches[max_r][-1] + need = group_size[gid] per_gpu_cost = compute_estimator(seq) proj_times = exec_times[:] @@ -533,7 +610,9 @@ def trim_overload(): total_work_before = sum(len(mb) for mb in micro_batches) # Check for empty GPUs and redistribute work - def fill_empty_gpus(micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size): + def fill_empty_gpus( + micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size + ): """ Recursively check for empty GPUs and redistribute work by increasing the number of GPUs sharing samples. This ensures all GPUs have work. @@ -542,8 +621,14 @@ def fill_empty_gpus(micro_batches, exec_times, sample_ids_per_gpu, group_members # Find empty GPUs empty_gpus = [i for i in range(total_gpus) if not micro_batches[i]] if not empty_gpus: - return micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size # No empty GPUs, we're done - + return ( + micro_batches, + exec_times, + sample_ids_per_gpu, + group_members, + group_size, + ) # No empty GPUs, we're done + # Find the smallest group size that exists existing_group_sizes = set(group_size.values()) if not existing_group_sizes: @@ -552,12 +637,12 @@ def fill_empty_gpus(micro_batches, exec_times, sample_ids_per_gpu, group_members min_group_size = min(existing_group_sizes) # We have Hybrid DPxCP groups for every power of 2 of GPUs or the entire DPxCP group. next_power = min(min_group_size * 2, total_gpus) - + # Find the first group of min_group_size that can be expanded expandable_gid = None expandable_members = None expandable_new_gpus = None - + for gid, size in group_size.items(): if size == min_group_size: members = group_members[gid] @@ -565,11 +650,14 @@ def fill_empty_gpus(micro_batches, exec_times, sample_ids_per_gpu, group_members group_start_gpu = members[0] group_end_gpu = members[-1] empty_gpu = [idx for idx, work in enumerate(micro_batches) if not work][0] - assert not all(work for work in micro_batches[empty_gpu:empty_gpu+needed_count]), f"Not enough empty GPUs to expand or there are empty GPUs between work scheduled which is not allowed." - work_to_push = micro_batches[group_end_gpu + 1 : empty_gpu] # This is work of all other subsequent sub-samples + assert not all( + work for work in micro_batches[empty_gpu : empty_gpu + needed_count] + ), f"Not enough empty GPUs to expand or there are empty GPUs between work scheduled which is not allowed." + work_to_push = micro_batches[ + group_end_gpu + 1 : empty_gpu + ] # This is work of all other subsequent sub-samples exec_times_to_push = exec_times[group_end_gpu + 1 : empty_gpu] sample_ids_to_push = sample_ids_per_gpu[group_end_gpu + 1 : empty_gpu] - new_micro_batches = [[]] * len(micro_batches) new_exec_times = [0.0] * len(exec_times) @@ -584,40 +672,53 @@ def fill_empty_gpus(micro_batches, exec_times, sample_ids_per_gpu, group_members # The work is distributed across the expanded group for i in range(group_start_gpu, group_end_gpu + needed_count + 1): new_micro_batches[i] = micro_batches[group_end_gpu] - new_exec_times[i] = self.get_total_workload(micro_batches[group_end_gpu][0], next_power) + new_exec_times[i] = self.get_total_workload( + micro_batches[group_end_gpu][0], next_power + ) new_sample_ids_per_gpu[i] = sample_ids_per_gpu[group_end_gpu] # Any assigned work on expanded GPUs is pushed for i, work in enumerate(work_to_push): new_micro_batches[group_end_gpu + needed_count + 1 + i] = work new_exec_times[group_end_gpu + needed_count + 1 + i] = exec_times_to_push[i] - new_sample_ids_per_gpu[group_end_gpu + needed_count + 1 + i] = sample_ids_to_push[i] - + new_sample_ids_per_gpu[group_end_gpu + needed_count + 1 + i] = ( + sample_ids_to_push[i] + ) + group_size[gid] = next_power group_members[gid] = list(range(members[0], members[-1] + needed_count + 1)) for pushed_gid in group_size.keys(): if pushed_gid > gid: - group_members[pushed_gid] = [x + needed_count for x in group_members[pushed_gid]] - - return new_micro_batches, new_exec_times, new_sample_ids_per_gpu, group_members, group_size - + group_members[pushed_gid] = [ + x + needed_count for x in group_members[pushed_gid] + ] + + return ( + new_micro_batches, + new_exec_times, + new_sample_ids_per_gpu, + group_members, + group_size, + ) empty_gpus = any([not micro_batches[i] for i in range(total_gpus)]) while empty_gpus: - micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size = fill_empty_gpus(micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size) + micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size = ( + fill_empty_gpus( + micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size + ) + ) empty_gpus = any([not micro_batches[i] for i in range(total_gpus)]) # Assert that no work has been completely removed total_work_after = sum(len(mb) for mb in micro_batches) - assert total_work_after >= total_work_before, f"Work was removed: {total_work_before} -> {total_work_after}" + assert ( + total_work_after >= total_work_before + ), f"Work was removed: {total_work_before} -> {total_work_after}" return micro_batches, leftovers, exec_times, sample_ids_per_gpu - def get_groups_and_subsamples( - self, - sample_id_seqlens, - config, - ): + def get_groups_and_subsamples(self, sample_id_seqlens, config): # TODO: Protect for model parallelism # TODO: Reduce access to file system as much as possible. groups = [] @@ -625,14 +726,17 @@ def get_groups_and_subsamples( # We assign a sample_id to each sub-sample in order to track the right assignment to each GPU. sample_id_seqlens = sorted(sample_id_seqlens, key=lambda x: x[1], reverse=True) while sample_id_seqlens: - mb, sample_id_seqlens, exec_times, sample_ids = self.next_hdp_group(sample_id_seqlens, self.get_total_workload, self.total_hdp_gpus) + mb, sample_id_seqlens, exec_times, sample_ids = self.next_hdp_group( + sample_id_seqlens, self.get_total_workload, self.total_hdp_gpus + ) groups.append(mb) if len(sample_ids) < self.total_hdp_gpus: sample_ids.extend([] * (self.total_hdp_gpus - len(sample_ids))) sample_id_groups.append(sample_ids) - + return groups, sample_id_groups + def hybrid_context_parallel_forward_backward( forward_step_func, data_iterator, @@ -660,8 +764,8 @@ def hybrid_context_parallel_forward_backward( A group is defined by a set of samples that can run across the CP domain without any barrier. There are many reasons why we may not be able to run endless number of samples within a single group. - For example, if we have 8 GPUs, - if GPU 0-5 are assigned a long sample that requires CP6, + For example, if we have 8 GPUs, + if GPU 0-5 are assigned a long sample that requires CP6, GPU 6-7 are assigned a short sample that requires CP2, The next sample which requires CP4 can be assigned GPU 4-7. But GPU 6-7 will finish first and get deadlocked if GPU 4-5 are not participating in the group. @@ -671,7 +775,7 @@ def hybrid_context_parallel_forward_backward( In the future, when we schedule over the entire global batch, we will remove the need for step #2 and number of microbatches will be determined by the number of groups. """ - from .schedules import forward_step, backward_step + from .schedules import backward_step, forward_step cp_balancing_scheduler = BalancedCPScheduler(max_seq_len_per_rank=config.max_seqlen_per_cp_rank) # We get data once per global batch and schedule the sub-samples. @@ -689,12 +793,18 @@ def hybrid_context_parallel_forward_backward( # Upto last group, we don't need any sync. with no_sync_func(): for j in range(len(sample_id_groups) - 1): - sample_ids_this_group = sample_id_groups[j][parallel_state.get_data_parallel_rank(with_context_parallel=True)] + sample_ids_this_group = sample_id_groups[j][ + parallel_state.get_data_parallel_rank(with_context_parallel=True) + ] for sub_sample_id in sample_ids_this_group: # Call forward step for each sub-sample sample = batch[sub_sample_id] - partner_cp_size = len([True for sample_ids in sample_id_groups[j] if sub_sample_id in sample_ids]) - assert partner_cp_size > 0, f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} sample_ids_this_group: {sample_id_groups[j]}" + partner_cp_size = len( + [True for sample_ids in sample_id_groups[j] if sub_sample_id in sample_ids] + ) + assert ( + partner_cp_size > 0 + ), f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} sample_ids_this_group: {sample_id_groups[j]}" sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) sample["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) new_data_iterator = RerunDataIterator(iter([sample])) @@ -708,26 +818,36 @@ def hybrid_context_parallel_forward_backward( forward_data_store, config, collect_non_loss_data, - is_first_microbatch=check_first_val_step(first_val_step, forward_only, num_microbatches == 1), + is_first_microbatch=check_first_val_step( + first_val_step, forward_only, num_microbatches == 1 + ), current_microbatch=num_microbatches - 1, ) total_num_tokens += num_tokens.item() if not forward_only: - backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - + backward_step( + input_tensor, output_tensor, output_tensor_grad, model_type, config + ) + # Create a barrier at end of each group. # This barrier ensures that all ranks are prepared to change assigned CP group sizes and # no rank is starting a sub-sample ahead of it's partner ranks. - torch.distributed.barrier(parallel_state.get_data_parallel_group(with_context_parallel=True)) + torch.distributed.barrier( + parallel_state.get_data_parallel_group(with_context_parallel=True) + ) # For the last group, we need to run the last sub-sample out of the context handler. # TODO: Find num sub-samples per group in this group? with no_sync_func(): - sample_ids_this_group = sample_id_groups[-1][parallel_state.get_data_parallel_rank(with_context_parallel=True)] + sample_ids_this_group = sample_id_groups[-1][ + parallel_state.get_data_parallel_rank(with_context_parallel=True) + ] for k in range(len(sample_ids_this_group) - 1): sub_sample_id = sample_ids_this_group[k] sample = batch[sub_sample_id] - partner_cp_size = len([True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids]) + partner_cp_size = len( + [True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids] + ) sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) sample["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) new_data_iterator = RerunDataIterator(iter([sample])) @@ -741,17 +861,23 @@ def hybrid_context_parallel_forward_backward( forward_data_store, config, collect_non_loss_data, - is_first_microbatch=check_first_val_step(first_val_step, forward_only, num_microbatches == 1), + is_first_microbatch=check_first_val_step( + first_val_step, forward_only, num_microbatches == 1 + ), current_microbatch=num_microbatches - 1, ) total_num_tokens += num_tokens.item() if not forward_only: backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - + # The last sub-sample of the last group of the last microbatch is run out of the context handler. sub_sample_id = sample_ids_this_group[-1] - partner_cp_size = len([True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids]) - assert partner_cp_size > 0, f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} sample_ids_this_group: {sample_id_groups[-1]}" + partner_cp_size = len( + [True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids] + ) + assert ( + partner_cp_size > 0 + ), f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} sample_ids_this_group: {sample_id_groups[-1]}" sample = batch[sub_sample_id] sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) sample["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) @@ -766,7 +892,9 @@ def hybrid_context_parallel_forward_backward( forward_data_store, config, collect_non_loss_data, - is_first_microbatch=check_first_val_step(first_val_step, forward_only, num_microbatches == 1), + is_first_microbatch=check_first_val_step( + first_val_step, forward_only, num_microbatches == 1 + ), current_microbatch=num_microbatches - 1, ) total_num_tokens += num_tokens.item() @@ -780,4 +908,4 @@ def hybrid_context_parallel_forward_backward( # we don't need to change the loss. # But if the loss calculation is different, then the user needs to define a new loss function # for hybrid context parallel in their training script. - return forward_data_store, total_num_tokens \ No newline at end of file + return forward_data_store, total_num_tokens diff --git a/megatron/core/utils.py b/megatron/core/utils.py index 0d605bae505..62300dc8f37 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -1800,7 +1800,9 @@ def is_submodule(module, parent_module, strict=True): ######################## -def get_batch_on_this_cp_rank(batch: Dict[str, Any], cp_size: Optional[int] = None, cp_rank: Optional[int] = None): +def get_batch_on_this_cp_rank( + batch: Dict[str, Any], cp_size: Optional[int] = None, cp_rank: Optional[int] = None +): """Slice batch input along sequence dimension into multiple chunks, which are parallelized across GPUs in a context parallel group. """ @@ -1812,7 +1814,9 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any], cp_size: Optional[int] = No # and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so # that we can get balanced workload among GPUs in a context parallel group. if cp_size is not None or cp_rank is not None: - assert cp_size is not None and cp_rank is not None, "Both cp_size and cp_rank must be provided for batch slicing" + assert ( + cp_size is not None and cp_rank is not None + ), "Both cp_size and cp_rank must be provided for batch slicing" if cp_size is None: cp_size = parallel_state.get_context_parallel_world_size() @@ -1837,9 +1841,10 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any], cp_size: Optional[int] = No return batch + def get_sub_sample_on_this_cp_rank(batch, scheduled_id, local_cp_size, packed_seq_params): """ - For a packed sequence, this function returns + For a packed sequence, this function returns 1. The sub-sample of the sequence assigned to this CP rank. 2. The appropriate CP group for the new CP assignment. 3. The updated packed sequence parameters. @@ -1852,7 +1857,7 @@ def get_sub_sample_on_this_cp_rank(batch, scheduled_id, local_cp_size, packed_se """ cu_lengths = packed_seq_params.cu_seqlens_q_padded start_index = cu_lengths[scheduled_id] - end_index = cu_lengths[scheduled_id+1] + end_index = cu_lengths[scheduled_id + 1] # TODO (flexible HCP): New CP size also means new padding requirement. CP4 to CP3 changes padding requirement. for key, data in batch.items(): if key in {'attention_mask', 'cu_seqlens', 'max_seqlen', 'scheduled_id', 'local_cp_size'}: @@ -1868,16 +1873,48 @@ def get_sub_sample_on_this_cp_rank(batch, scheduled_id, local_cp_size, packed_se sub_sample_packed_seq_params = PackedSeqParams( qkv_format="sbhd", - cu_seqlens_q=torch.tensor([0, packed_seq_params.cu_seqlens_q[scheduled_id+1] - packed_seq_params.cu_seqlens_q[scheduled_id]], device="cuda", pin_memory=True), - cu_seqlens_kv=torch.tensor([0, packed_seq_params.cu_seqlens_kv[scheduled_id+1] - packed_seq_params.cu_seqlens_kv[scheduled_id]], device="cuda", pin_memory=True), - cu_seqlens_q_padded=torch.tensor([0, packed_seq_params.cu_seqlens_q_padded[scheduled_id+1] - packed_seq_params.cu_seqlens_q_padded[scheduled_id]], device="cuda", pin_memory=True), - cu_seqlens_kv_padded=torch.tensor([0, packed_seq_params.cu_seqlens_kv_padded[scheduled_id+1] - packed_seq_params.cu_seqlens_kv_padded[scheduled_id]], device="cuda", pin_memory=True), + cu_seqlens_q=torch.tensor( + [ + 0, + packed_seq_params.cu_seqlens_q[scheduled_id + 1] + - packed_seq_params.cu_seqlens_q[scheduled_id], + ], + device="cuda", + pin_memory=True, + ), + cu_seqlens_kv=torch.tensor( + [ + 0, + packed_seq_params.cu_seqlens_kv[scheduled_id + 1] + - packed_seq_params.cu_seqlens_kv[scheduled_id], + ], + device="cuda", + pin_memory=True, + ), + cu_seqlens_q_padded=torch.tensor( + [ + 0, + packed_seq_params.cu_seqlens_q_padded[scheduled_id + 1] + - packed_seq_params.cu_seqlens_q_padded[scheduled_id], + ], + device="cuda", + pin_memory=True, + ), + cu_seqlens_kv_padded=torch.tensor( + [ + 0, + packed_seq_params.cu_seqlens_kv_padded[scheduled_id + 1] + - packed_seq_params.cu_seqlens_kv_padded[scheduled_id], + ], + device="cuda", + pin_memory=True, + ), max_seqlen_q=end_index - start_index, max_seqlen_kv=end_index - start_index, local_cp_size=local_cp_size, ) # TODO: Should we return the sharded sample directly here? - + return batch, cp_group, sub_sample_packed_seq_params From 69ef3999e9a0d598ba5dc8f0651b118e581efa3f Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Tue, 16 Sep 2025 14:01:04 -0700 Subject: [PATCH 26/81] Remove scheduled id and other finished TODO items --- .../core/extensions/transformer_engine.py | 1 - megatron/core/packed_seq_params.py | 1 - .../pipeline_parallel/hybrid_cp_schedule.py | 15 ------- megatron/core/utils.py | 44 ------------------- megatron/training/utils.py | 13 ------ pretrain_gpt.py | 15 +------ 6 files changed, 2 insertions(+), 87 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 133043ee9f5..55a34051a23 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -968,7 +968,6 @@ def __init__( self.kept_packed_seq_params = set( field.name for field in dataclasses.fields(PackedSeqParams) ) - self.kept_packed_seq_params.discard("scheduled_id") if get_te_version() < PkgVersion("1.3.0"): # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H diff --git a/megatron/core/packed_seq_params.py b/megatron/core/packed_seq_params.py index ea5e9463187..78a6c7db496 100644 --- a/megatron/core/packed_seq_params.py +++ b/megatron/core/packed_seq_params.py @@ -21,5 +21,4 @@ class PackedSeqParams: max_seqlen_q: int = None max_seqlen_kv: int = None local_cp_size: int = None - scheduled_id: int = None cp_group: dist.ProcessGroup = None diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index 047b5189a74..64c80d113bb 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -618,8 +618,6 @@ def get_groups_and_subsamples( sample_id_seqlens, config, ): - # TODO: Protect for model parallelism - # TODO: Reduce access to file system as much as possible. groups = [] sample_id_groups = [] # We assign a sample_id to each sub-sample in order to track the right assignment to each GPU. @@ -675,10 +673,6 @@ def hybrid_context_parallel_forward_backward( cp_balancing_scheduler = BalancedCPScheduler(max_seq_len_per_rank=config.max_seqlen_per_cp_rank) # We get data once per global batch and schedule the sub-samples. - # TODO(pmannan): We will change from 3 for loop to 2 for loops. - # groups and sample_id_groups will be provided by the scheduler. - # sample_id_groups will now directly be a list of sub-samples per group and we iterate over it. - # data will be a list of sub-samples with sample_id. # TODO(pmannan): Should we wrap the data_iterator here instead of the training.py file? data = next(data_iterator) sample_id_groups = data[1] @@ -696,7 +690,6 @@ def hybrid_context_parallel_forward_backward( partner_cp_size = len([True for sample_ids in sample_id_groups[j] if sub_sample_id in sample_ids]) assert partner_cp_size > 0, f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} sample_ids_this_group: {sample_id_groups[j]}" sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) - sample["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) new_data_iterator = RerunDataIterator(iter([sample])) # TODO: Find the usage of current_microbatch and is_first_microbatch and how that may affect my usage. output_tensor, num_tokens = forward_step( @@ -721,7 +714,6 @@ def hybrid_context_parallel_forward_backward( torch.distributed.barrier(parallel_state.get_data_parallel_group(with_context_parallel=True)) # For the last group, we need to run the last sub-sample out of the context handler. - # TODO: Find num sub-samples per group in this group? with no_sync_func(): sample_ids_this_group = sample_id_groups[-1][parallel_state.get_data_parallel_rank(with_context_parallel=True)] for k in range(len(sample_ids_this_group) - 1): @@ -729,7 +721,6 @@ def hybrid_context_parallel_forward_backward( sample = batch[sub_sample_id] partner_cp_size = len([True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids]) sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) - sample["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) new_data_iterator = RerunDataIterator(iter([sample])) # Call forward step for each sub-sample output_tensor, num_tokens = forward_step( @@ -754,7 +745,6 @@ def hybrid_context_parallel_forward_backward( assert partner_cp_size > 0, f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} sample_ids_this_group: {sample_id_groups[-1]}" sample = batch[sub_sample_id] sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) - sample["scheduled_id"] = torch.tensor(sub_sample_id, dtype=torch.int32) new_data_iterator = RerunDataIterator(iter([sample])) # Call forward step for each sub-sample output_tensor, num_tokens = forward_step( @@ -775,9 +765,4 @@ def hybrid_context_parallel_forward_backward( torch.distributed.barrier(parallel_state.get_data_parallel_group(with_context_parallel=True)) - # TODO: Before returning forward_data_store, do we need to change the loss? - # If loss calculation is done as sum(loss_per_token) / sum(total_tokens_per_sample), - # we don't need to change the loss. - # But if the loss calculation is different, then the user needs to define a new loss function - # for hybrid context parallel in their training script. return forward_data_store, total_num_tokens \ No newline at end of file diff --git a/megatron/core/utils.py b/megatron/core/utils.py index 0d605bae505..98ab5676100 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -1837,50 +1837,6 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any], cp_size: Optional[int] = No return batch -def get_sub_sample_on_this_cp_rank(batch, scheduled_id, local_cp_size, packed_seq_params): - """ - For a packed sequence, this function returns - 1. The sub-sample of the sequence assigned to this CP rank. - 2. The appropriate CP group for the new CP assignment. - 3. The updated packed sequence parameters. - - Args: - batch: The batch of data to slice. - scheduled_id: The index of the sub-sample to return. - local_cp_size: The size of the CP group. - packed_seq_params: The updated packed sequence parameters. - """ - cu_lengths = packed_seq_params.cu_seqlens_q_padded - start_index = cu_lengths[scheduled_id] - end_index = cu_lengths[scheduled_id+1] - # TODO (flexible HCP): New CP size also means new padding requirement. CP4 to CP3 changes padding requirement. - for key, data in batch.items(): - if key in {'attention_mask', 'cu_seqlens', 'max_seqlen', 'scheduled_id', 'local_cp_size'}: - continue - batch[key] = data[:, start_index:end_index] - - # TODO (milestone 2): Enable this when we do DPxCP - # cp_group = parallel_state.get_hybrid_data_context_parallel_groups(group_size=local_cp_size) - if local_cp_size > 1: - cp_group = parallel_state.get_hybrid_context_parallel_groups(group_size=local_cp_size) - else: - cp_group = None - - sub_sample_packed_seq_params = PackedSeqParams( - qkv_format="sbhd", - cu_seqlens_q=torch.tensor([0, packed_seq_params.cu_seqlens_q[scheduled_id+1] - packed_seq_params.cu_seqlens_q[scheduled_id]], device="cuda", pin_memory=True), - cu_seqlens_kv=torch.tensor([0, packed_seq_params.cu_seqlens_kv[scheduled_id+1] - packed_seq_params.cu_seqlens_kv[scheduled_id]], device="cuda", pin_memory=True), - cu_seqlens_q_padded=torch.tensor([0, packed_seq_params.cu_seqlens_q_padded[scheduled_id+1] - packed_seq_params.cu_seqlens_q_padded[scheduled_id]], device="cuda", pin_memory=True), - cu_seqlens_kv_padded=torch.tensor([0, packed_seq_params.cu_seqlens_kv_padded[scheduled_id+1] - packed_seq_params.cu_seqlens_kv_padded[scheduled_id]], device="cuda", pin_memory=True), - max_seqlen_q=end_index - start_index, - max_seqlen_kv=end_index - start_index, - local_cp_size=local_cp_size, - ) - # TODO: Should we return the sharded sample directly here? - - return batch, cp_group, sub_sample_packed_seq_params - - ###################### ### NVTX profiling ### ###################### diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 571101b8299..33a9d0c1c51 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -536,11 +536,6 @@ def _broadcast(item): if "max_seqlen" not in data else data["max_seqlen"].cuda(non_blocking=True) ), - 'scheduled_id': ( - None - if "scheduled_id" not in data - else data["scheduled_id"].cuda(non_blocking=True) - ), 'local_cp_size': ( None if "local_cp_size" not in data @@ -573,7 +568,6 @@ def _broadcast_cu_seqlens(cu_seqlens): _broadcast(batch['position_ids']) _broadcast_cu_seqlens(batch['cu_seqlens']) _broadcast(batch['max_seqlen']) - _broadcast(batch['scheduled_id']) _broadcast(batch['local_cp_size']) elif mpu.is_pipeline_first_stage(): @@ -633,11 +627,6 @@ def _broadcast_cu_seqlens(cu_seqlens): dtype=torch.int32, device=torch.cuda.current_device(), ) - scheduled_id = torch.empty( - 1, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) local_cp_size = torch.empty( 1, dtype=torch.int32, @@ -667,7 +656,6 @@ def _broadcast_cu_seqlens(): _broadcast(position_ids) cu_seqlens = _broadcast_cu_seqlens() _broadcast(max_seqlen) - _broadcast(scheduled_id) _broadcast(local_cp_size) elif mpu.is_pipeline_first_stage(): @@ -707,7 +695,6 @@ def _broadcast_cu_seqlens(): 'position_ids': position_ids, 'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, - 'scheduled_id': scheduled_id, 'local_cp_size': local_cp_size, } diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 5a5cf551bbd..c0533452d07 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -19,7 +19,7 @@ from megatron.core.enums import ModelType from megatron.core.models.gpt import GPTModel from megatron.core.rerun_state_machine import get_rerun_state_machine -from megatron.core.utils import get_attr_wrapped_model, is_te_min_version, StragglerDetector, get_sub_sample_on_this_cp_rank +from megatron.core.utils import get_attr_wrapped_model, is_te_min_version, StragglerDetector from megatron.core.tokenizers.text.utils.build_tokenizer import build_tokenizer from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0 from megatron.training.utils import ( @@ -62,15 +62,10 @@ def get_batch(data_iterator, vp_stage=None): # get batches based on the TP rank you are on batch = get_batch_on_this_tp_rank(data_iterator) - # slice batch along sequence dimension for context parallelism - # batch = get_batch_on_this_cp_rank(batch) - cu_seqlens = batch.pop('cu_seqlens') max_seqlen = batch.pop('max_seqlen') - scheduled_id = batch.pop('scheduled_id') local_cp_size = batch.pop('local_cp_size') - if scheduled_id is not None: - scheduled_id = int(scheduled_id.item()) + if local_cp_size is not None: local_cp_size = int(local_cp_size.item()) if cu_seqlens is not None: @@ -120,8 +115,6 @@ def get_batch(data_iterator, vp_stage=None): batch[key] = data.index_select(1, index) else: # Hybrid CP format assert local_cp_size is not None - assert scheduled_id is not None - # batch, cp_group, packed_seq_params = get_sub_sample_on_this_cp_rank(batch, scheduled_id, local_cp_size, packed_seq_params) if local_cp_size > 1: cp_group = parallel_state.get_hybrid_data_context_parallel_groups(group_size=local_cp_size) else: @@ -134,10 +127,6 @@ def get_batch(data_iterator, vp_stage=None): batch[key] = torch.stack([data], 0) sample_length = batch['tokens'].shape[1] # Create packed_seq_params for SBHD format with cp group information. - # TODO(pmannan): Since entire PackedSeqParams is not needed, should we create a new dataclass with our information? - # We will need to update the logic in extensions/transformer_engine to support this. - # Piping through a new dataclass from training script might be adding extra overhead. - # Take ADLR recommendation on this. packed_seq_params = PackedSeqParams( qkv_format="sbhd", cu_seqlens_q=torch.tensor([0, sample_length], device="cuda", pin_memory=True), From 3bafa2c440435bd16695ec96cdc619b3fe0632db Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Tue, 16 Sep 2025 16:19:49 -0700 Subject: [PATCH 27/81] Formatting and cleanup --- .../common/embeddings/rotary_pos_embedding.py | 4 +- megatron/core/packed_seq_params.py | 1 - megatron/core/parallel_state.py | 5 + .../pipeline_parallel/hybrid_cp_schedule.py | 142 ++++++++++-------- megatron/core/utils.py | 3 +- 5 files changed, 90 insertions(+), 65 deletions(-) diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index 2306f43631c..af8ea822f3b 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -198,8 +198,8 @@ def forward( else: cp_group = None if cp_group is not None: - # slice rotary_pos_emb along sequence dimension and select the parition of the current - # CP rank + # slice rotary_pos_emb along sequence dimension + # and select the parition of the current CP rank emb = get_pos_emb_on_this_cp_rank(emb, 0, cp_group) return emb diff --git a/megatron/core/packed_seq_params.py b/megatron/core/packed_seq_params.py index d1e5f9cd305..08ebdac67d8 100644 --- a/megatron/core/packed_seq_params.py +++ b/megatron/core/packed_seq_params.py @@ -1,6 +1,5 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from dataclasses import dataclass -from typing import List import torch.distributed as dist from torch import Tensor diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 87e92b8d3e7..0d86aeda1a0 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -368,6 +368,11 @@ def create_hierarchical_groups( def create_hybrid_dp_cp_groups(rank, ranks, pg_options): + """ + Creates groups required for hybrid DPxCP. + Creates a new group for every power of 2 up to the number of DPxCP ranks. + Returns a dictionary indexed by group size. + """ hybrid_dp_cp_groups = {} # Generate group for every power of 2 up to the number of CP ranks # We limit the allowed group sizes in order to avoid excessive overhead. diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index bba827f0a62..1563178b951 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -1,13 +1,11 @@ -import heapq from collections import deque from functools import lru_cache from math import ceil, log2 -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, List, Optional, Tuple import torch from megatron.core import parallel_state -from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.rerun_state_machine import RerunDataIterator @@ -73,6 +71,14 @@ def __iter__(self): return self def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: + """ + Gathers the sequence lengths of all subsamples from all DP ranks. + Each DP rank loads the same number of microbatches but each microbatch + may have a different number of subsamples. + + We find the number of subsamples each rank holds and then gather the + sequence lengths of all subsamples from all ranks. + """ # Collect the number of subsamples from all ranks local_len = torch.tensor([subsample_seqlens.shape[0]], dtype=torch.int32).cuda() dp_subsample_count = [ @@ -83,7 +89,7 @@ def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: dp_subsample_count, local_len, group=parallel_state.get_data_parallel_group() ) - # Find the maximum number of subsamples across all ranks and pad the subsample_seqlens to the max length + # Find the max number of subsamples across all ranks and pad subsample_seqlens to max length dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1) max_sub_samples = int(dp_subsample_counts.max().item()) @@ -123,7 +129,15 @@ def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: return seqlens_gathered, offsets def get_global_id_seqlens(self, num_local_subsamples, offsets, seqlens_gathered): - # Calculate the global ID for each subsample + """ + Calculates the global ID for each subsample. + + We assign a unique global ID to each subsample. + + Returns: + global_id_seqlens: list of (global_id, seqlen) tuples for scheduling. + global_ids_this_rank: list of global IDs locally present on this rank. + """ dp_rank = parallel_state.get_data_parallel_rank() global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda() # Create a list of (global_id, seqlen) tuples for scheduling @@ -145,6 +159,14 @@ def _gid_to_src_rank(self, gid: int, offsets: List[int]) -> int: def reroute_samples_to_hdp_ranks( self, batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets ): + """ + Reroutes the sub-samples to the correct rank after scheduling. + + For each key in the batch dict, we perform an all-to-all communication + to transfer the data to the correct ranks. + Since all CP ranks within a DP group have the same data, we only need + to transfer data between matching CP ranks. + """ gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} hdp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) dp_ranks = torch.distributed.get_process_group_ranks( @@ -197,8 +219,6 @@ def reroute_samples_to_hdp_ranks( [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]] ) - # print(f"rank: {torch.distributed.get_rank()}, recv_lens_split: {recv_lens_split} send_lens_split: {send_lens_split}") - recv_ids_sorted = [ gid for d in range(self.total_hdp_gpus) for gid in recv_sample_id_groups[d] ] @@ -244,6 +264,12 @@ def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): return recv_sample_with_id def unpack_batch(self, batch): + """ + Unpacks the packed samples into a list of sub-samples. + Since each sub-sample may be routed to different DPxCP ranks, + we unpack the sample here to avoid unnecessarily transferring + the entire packed sample. + """ batch_unpacked = [] for sample in batch: for sub_sample in range(sample["cu_seqlens"].shape[0] - 1): @@ -280,24 +306,24 @@ def __next__(self) -> Any: global_id_seqlens, global_ids_this_rank = self.get_global_id_seqlens( subsample_seqlens.shape[0], offsets, seqlens_gathered ) - # global_id_seqlens = sorted(global_id_seqlens, key=lambda x: x[1], reverse=True) groups, sample_id_groups = self.cp_balancing_scheduler.get_groups_and_subsamples( global_id_seqlens, self.config ) - # sample["groups"] = groups - # sample["sample_id_groups"] = sample_id_groups batch = self.unpack_batch(batch) samples_this_rank_with_id = self.reroute_samples_to_hdp_ranks( batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets ) - # for sample_id, sample in samples_this_rank_with_id.items(): - # sample["sample_id_groups"] return samples_this_rank_with_id, sample_id_groups class BalancedCPScheduler: + """ + This class provides the functionality to form groups of sub-samples + such that all DPxCP ranks have a roughly balanced workload in the group. + """ + def __init__(self, max_seq_len_per_rank: int): self.max_seq_len_per_rank = max_seq_len_per_rank self.num_subsamples = 0 @@ -326,6 +352,11 @@ def get_total_workload(self, seq_length: int, cp_size: Optional[int] = None): @lru_cache(maxsize=128) def gpus_needed(self, seq_len: int) -> int: + """ + Calculates the number of GPUs needed for a given sequence length + and max sequence length per CP rank. + This is used to determine the CP size of a sub-sample. + """ return max(1, 2 ** ceil(log2((seq_len / self.max_seq_len_per_rank)))) def make_buckets_equal( @@ -388,23 +419,26 @@ def next_hdp_group( ) -> Tuple[List[List[int]], List[Tuple[int, int]], List[float], List[List[int]]]: """ Given a list of (sample_id, sequence_length) tuples, this function aims to assign - sequences to a microbatch such that all GPUs in the CP domain have a roughly balanced workload. - Once each microbatch is roughly balanced, we exit and return the microbatch and the leftover sequences. + sequences in a group such that all GPUs in the DPxCP group have a roughly balanced + workload. Once each group is roughly balanced, we exit and return the + group and the leftover sequences. The function performs the following passes in order to form a balanced microbatch: 1. We create buckets of sequences that are roughly balanced. We try to create as many buckets as possible CP sizes. - 2. Given a bucket has sequences available, we assign the microbatch + 2. Given a bucket has sequences available, we assign the sample a. To a new set of GPUs if there are enough free GPUs. b. To an existing set of GPUs with the lowest load. - 3. We check if the microbatch is balanced whenever we need to move onto a new CP size in the same set of GPUs. - 4. We trim the microbatch if removing the last added sequence helps improve balance. + 3. We check if the group is balanced whenever we need to move onto a new CP size + in the same set of GPUs. + 4. We trim the group if removing the last added sequence helps improve balance. 5. If we run out of sequences to assign and there are empty GPUs, - we redistribute work to empty GPUs by recursively increasing the CP size of a sample until no empty GPUs are left.. + we redistribute work to empty GPUs by recursively increasing the CP size of a + sample until no empty GPUs are left. #TODO: Add clarification on when we check for balance. What does prev_needed do? - Returns (*micro_batches*, *leftover_sample_seqlens*, *exec_times*, *sample_ids_per_gpu*). + Returns (micro_batches, leftover_sample_seqlens, exec_times, sample_ids_per_gpu). """ if not sample_seqlens: return ( @@ -525,30 +559,32 @@ def next_hdp_group( buckets.pop(0) pp_cursor %= max(1, len(buckets)) - # TODO: Should I pre-emptively break out if slack is already within delta? - # Feels like if we have global batch level samples, we will have lots with same CP size. - # So we can just keep adding samples. - # We already have trim workload to handle imbalanced cases. - # TODO: Removing this helps reduce the number of groups when we have lots of samples with same CP size. - # But because we don't exit as soon as we get balanced, even if there is one group available that can take the next sample, + # TODO: Removing this helps reduce the number of groups when we have + # lots of samples with same CP size. + # But because we don't exit as soon as we get balanced, + # even if there is one group available that can take the next sample, # we will keep adding samples to the same group. - # trim_overload() does not help because it only checks if removing the last added sample helps. - # We cannot check after adding every sample because there will always be imbalance if we don't wait for future scheduling. + # trim_overload() does not help because it only checks if removing the + # last added sample helps. + # We cannot check after adding every sample because there will always be imbalance + # if we don't wait for future scheduling. # IMPORTANT: So we need a solution here if needed < prev_needed: - # When we get into a lower CP size in the same group, we can start checking for balance. - # There is still a gotcha here. + # When we get into a lower CP size in the same group, + # we can start checking for balance. There is still a gotcha here. # Let's say we have a group of 3 GPU 0-2, then we move onto group of 2. - # We keep assigning group of 2 as we do in descending order but GPU 7/15 never sees a microbatch assigned to it + # We keep assigning group of 2 as we do in descending order but GPU 7/15 + # never sees a microbatch assigned to it # until we run out of samples with CP2. # This means we are never balanced as min(exec_times) will always be 0. - # We need a smart way of identifying that we have run out of big samples and if we are having to - # assign work to a GPU already working, is it because there are empty GPUs? + # We need a smart way of identifying that we have run out of big samples + # and if we are having to assign work to a GPU already working, + # is it because there are empty GPUs? # Would assigning work to empty GPUs first by moving onto next CP bucket help? - # But we need to remember to come back to this CP size bucket and then check for balance. - # Maybe the scheduling algorithm should look at empty GPUs and find work rather than going - # sequence by sequence. + # But we need to remember to come back to this CP size bucket and then + # check for balance. Maybe the scheduling algorithm should look at empty + # GPUs and find work rather than going sequence by sequence. check_balance = True if ( @@ -652,7 +688,7 @@ def fill_empty_gpus( empty_gpu = [idx for idx, work in enumerate(micro_batches) if not work][0] assert not all( work for work in micro_batches[empty_gpu : empty_gpu + needed_count] - ), f"Not enough empty GPUs to expand or there are empty GPUs between work scheduled which is not allowed." + ), f"Empty GPUs were detected but not enough to expand." work_to_push = micro_batches[ group_end_gpu + 1 : empty_gpu ] # This is work of all other subsequent sub-samples @@ -719,9 +755,13 @@ def fill_empty_gpus( return micro_batches, leftovers, exec_times, sample_ids_per_gpu def get_groups_and_subsamples(self, sample_id_seqlens, config): + """ + This function recursively forms groups of sub-samples such that all DPxCP ranks + have a roughly balanced workload in the group. + """ groups = [] sample_id_groups = [] - # We assign a sample_id to each sub-sample in order to track the right assignment to each GPU. + # We assign a sample_id to each sub-sample in order to track assignment to each GPU. sample_id_seqlens = sorted(sample_id_seqlens, key=lambda x: x[1], reverse=True) while sample_id_seqlens: mb, sample_id_seqlens, exec_times, sample_ids = self.next_hdp_group( @@ -761,17 +801,12 @@ def hybrid_context_parallel_forward_backward( 3. The number of sub-samples per group each CP rank should execute A group is defined by a set of samples that can run across the CP domain without any barrier. - There are many reasons why we may not be able to run endless number of samples within a single group. + There are many reasons why we may not be able to run endless samples within a single group. For example, if we have 8 GPUs, if GPU 0-5 are assigned a long sample that requires CP6, GPU 6-7 are assigned a short sample that requires CP2, The next sample which requires CP4 can be assigned GPU 4-7. But GPU 6-7 will finish first and get deadlocked if GPU 4-5 are not participating in the group. - - As of now, the number of microbatches is pre-determined by GBS and DP size. - We perform the scheduling for each microbatch. - In the future, when we schedule over the entire global batch, we will remove the need for step #2 and - number of microbatches will be determined by the number of groups. """ from .schedules import backward_step, forward_step @@ -796,12 +831,10 @@ def hybrid_context_parallel_forward_backward( partner_cp_size = len( [True for sample_ids in sample_id_groups[j] if sub_sample_id in sample_ids] ) - assert ( - partner_cp_size > 0 - ), f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} sample_ids_this_group: {sample_id_groups[j]}" sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) new_data_iterator = RerunDataIterator(iter([sample])) - # TODO: Find the usage of current_microbatch and is_first_microbatch and how that may affect my usage. + # TODO: Find the usage of current_microbatch and is_first_microbatch and + # how that may affect my usage. output_tensor, num_tokens = forward_step( forward_step_func, new_data_iterator, @@ -861,14 +894,12 @@ def hybrid_context_parallel_forward_backward( if not forward_only: backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - # The last sub-sample of the last group of the last microbatch is run out of the context handler. + # The last sub-sample of the last group of the last microbatch is + # run out of the context handler. sub_sample_id = sample_ids_this_group[-1] partner_cp_size = len( [True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids] ) - assert ( - partner_cp_size > 0 - ), f"rank: {torch.distributed.get_rank()}, sub_sample_id: {sub_sample_id} sample_ids_this_group: {sample_id_groups[-1]}" sample = batch[sub_sample_id] sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) new_data_iterator = RerunDataIterator(iter([sample])) @@ -893,13 +924,4 @@ def hybrid_context_parallel_forward_backward( torch.distributed.barrier(parallel_state.get_data_parallel_group(with_context_parallel=True)) -<<<<<<< HEAD - return forward_data_store, total_num_tokens -======= - # TODO: Before returning forward_data_store, do we need to change the loss? - # If loss calculation is done as sum(loss_per_token) / sum(total_tokens_per_sample), - # we don't need to change the loss. - # But if the loss calculation is different, then the user needs to define a new loss function - # for hybrid context parallel in their training script. return forward_data_store, total_num_tokens ->>>>>>> dc8139f2ddd22bb1a70ded90f33d2280b6f8c48e diff --git a/megatron/core/utils.py b/megatron/core/utils.py index 4b35e838e06..125663cf7ec 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -6,7 +6,6 @@ import asyncio import functools import hashlib -import heapq import inspect import logging import math @@ -30,7 +29,6 @@ from megatron.core import config from megatron.core.package_info import __version__ as mcore_version -from megatron.core.packed_seq_params import PackedSeqParams try: from torch.distributed._tensor import DTensor @@ -1841,6 +1839,7 @@ def get_batch_on_this_cp_rank( return batch + ###################### ### NVTX profiling ### ###################### From 78dec32c8177d7e4fa8aced993f57471d984db8f Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Wed, 17 Sep 2025 15:04:51 -0700 Subject: [PATCH 28/81] Add copyright --- megatron/core/pipeline_parallel/hybrid_cp_schedule.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index 1563178b951..5eb62fbd9af 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -1,3 +1,5 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + from collections import deque from functools import lru_cache from math import ceil, log2 From 6b1d9c4427ce255c564af6df164f124de25bfdca Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Wed, 17 Sep 2025 16:18:58 -0700 Subject: [PATCH 29/81] Start testing --- .../pipeline_parallel/hybrid_cp_schedule.py | 1 - .../transformer/multi_latent_attention.py | 2 +- tests/unit_tests/test_parallel_state.py | 30 +++++++++++++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index 5eb62fbd9af..c599a21d382 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -59,7 +59,6 @@ class HybridCPDataLoaderWrapper: def __init__(self, data_iterator, config): self.data_iterator = data_iterator - self.sample_count = 0 self.config = config self.cp_balancing_scheduler = BalancedCPScheduler( max_seq_len_per_rank=self.config.max_seqlen_per_cp_rank diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index 9e6b46fd4e5..518e4c2de48 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -492,7 +492,7 @@ def get_query_key_value_tensors( rotary_pos_sin = None packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' if self.config.rope_type == "rope": - rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq_params=packed_seq_params) else: if self.config.apply_rope_fusion: rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb.get_cached_cos_sin( diff --git a/tests/unit_tests/test_parallel_state.py b/tests/unit_tests/test_parallel_state.py index ddee638776c..c6d5bb9f91d 100644 --- a/tests/unit_tests/test_parallel_state.py +++ b/tests/unit_tests/test_parallel_state.py @@ -496,3 +496,33 @@ def golden_rank_result_from_past_code( assert expert_dp_group == expert_rank_generator.get_ranks( "dp" ), f"{expert_dp_group} != {expert_rank_generator.get_ranks('dp')}." + + +@pytest.mark.parametrize( + "world_size, tp_size, cp_size, dp_size", + [ + (8, 1, 2, 4), # 8 GPUs, 1 TP, 2 CP, 4 DP + (8, 1, 1, 8), # 8 GPUs, 1 TP, 1 CP, 8 DP + ], +) +def test_hybrid_dp_cp_groups(world_size, tp_size, cp_size, dp_size): + """ + Test that hybrid DPxCP groups are created correctly. + """ + Utils.destroy_model_parallel() + + # Skip if world size doesn't match + actual_world_size = torch.cuda.device_count() + if actual_world_size != world_size: + pytest.skip(f"Test requires world_size={world_size}, but got {actual_world_size}") + Utils.initialize_model_parallel( + tensor_model_parallel_size=tp_size, context_parallel_size=cp_size, hybrid_data_context_parallel=True + ) + + dp_cp_size = parallel_state.get_data_parallel_world_size(with_context_parallel=True) + group_sizes = [2**i for i in range(int(log2(dp_cp_size)))][1:] + for group_size in group_sizes: + group = parallel_state.get_hybrid_data_context_parallel_groups(group_size=group_size) + assert group.size() == group_size + + Utils.destroy_model_parallel() From 71c2966057ceeaeb07ff4eab5072363262a1f902 Mon Sep 17 00:00:00 2001 From: Mcore Bot Date: Wed, 17 Sep 2025 23:23:44 +0000 Subject: [PATCH 30/81] chore: Format files --- megatron/core/transformer/multi_latent_attention.py | 4 +++- tests/unit_tests/test_parallel_state.py | 13 ++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index 518e4c2de48..bd9a0cee01a 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -492,7 +492,9 @@ def get_query_key_value_tensors( rotary_pos_sin = None packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' if self.config.rope_type == "rope": - rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq_params=packed_seq_params) + rotary_pos_emb = self.rotary_pos_emb( + rotary_seq_len, packed_seq_params=packed_seq_params + ) else: if self.config.apply_rope_fusion: rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb.get_cached_cos_sin( diff --git a/tests/unit_tests/test_parallel_state.py b/tests/unit_tests/test_parallel_state.py index c6d5bb9f91d..6fe6e9bd037 100644 --- a/tests/unit_tests/test_parallel_state.py +++ b/tests/unit_tests/test_parallel_state.py @@ -500,23 +500,22 @@ def golden_rank_result_from_past_code( @pytest.mark.parametrize( "world_size, tp_size, cp_size, dp_size", - [ - (8, 1, 2, 4), # 8 GPUs, 1 TP, 2 CP, 4 DP - (8, 1, 1, 8), # 8 GPUs, 1 TP, 1 CP, 8 DP - ], + [(8, 1, 2, 4), (8, 1, 1, 8)], # 8 GPUs, 1 TP, 2 CP, 4 DP # 8 GPUs, 1 TP, 1 CP, 8 DP ) def test_hybrid_dp_cp_groups(world_size, tp_size, cp_size, dp_size): """ Test that hybrid DPxCP groups are created correctly. """ Utils.destroy_model_parallel() - + # Skip if world size doesn't match actual_world_size = torch.cuda.device_count() if actual_world_size != world_size: pytest.skip(f"Test requires world_size={world_size}, but got {actual_world_size}") Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, context_parallel_size=cp_size, hybrid_data_context_parallel=True + tensor_model_parallel_size=tp_size, + context_parallel_size=cp_size, + hybrid_data_context_parallel=True, ) dp_cp_size = parallel_state.get_data_parallel_world_size(with_context_parallel=True) @@ -524,5 +523,5 @@ def test_hybrid_dp_cp_groups(world_size, tp_size, cp_size, dp_size): for group_size in group_sizes: group = parallel_state.get_hybrid_data_context_parallel_groups(group_size=group_size) assert group.size() == group_size - + Utils.destroy_model_parallel() From bff834ccf003e1e1bfbbc7d9e29860ecda92b3e2 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Wed, 17 Sep 2025 17:11:26 -0700 Subject: [PATCH 31/81] Fix arg name in test --- tests/unit_tests/test_parallel_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/test_parallel_state.py b/tests/unit_tests/test_parallel_state.py index c6d5bb9f91d..dbafffc3c38 100644 --- a/tests/unit_tests/test_parallel_state.py +++ b/tests/unit_tests/test_parallel_state.py @@ -516,7 +516,7 @@ def test_hybrid_dp_cp_groups(world_size, tp_size, cp_size, dp_size): if actual_world_size != world_size: pytest.skip(f"Test requires world_size={world_size}, but got {actual_world_size}") Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, context_parallel_size=cp_size, hybrid_data_context_parallel=True + tensor_model_parallel_size=tp_size, context_parallel_size=cp_size, hybrid_context_parallel=True ) dp_cp_size = parallel_state.get_data_parallel_world_size(with_context_parallel=True) From ff5600b6b79396237c08305ac8b32dbdb5c49cea Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Wed, 17 Sep 2025 17:12:46 -0700 Subject: [PATCH 32/81] Fix arg name in test --- tests/unit_tests/test_parallel_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/test_parallel_state.py b/tests/unit_tests/test_parallel_state.py index 6fe6e9bd037..a51c51db88d 100644 --- a/tests/unit_tests/test_parallel_state.py +++ b/tests/unit_tests/test_parallel_state.py @@ -515,7 +515,7 @@ def test_hybrid_dp_cp_groups(world_size, tp_size, cp_size, dp_size): Utils.initialize_model_parallel( tensor_model_parallel_size=tp_size, context_parallel_size=cp_size, - hybrid_data_context_parallel=True, + hybrid_context_parallel=True, ) dp_cp_size = parallel_state.get_data_parallel_world_size(with_context_parallel=True) From 7ff37e5542fbd608c4813ab91c1a5bbb0b69258e Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Wed, 17 Sep 2025 18:30:32 -0700 Subject: [PATCH 33/81] Update RoPE logic to handle none packed_seq_params --- .../common/embeddings/rotary_pos_embedding.py | 19 +++++++++++-------- tests/unit_tests/test_parallel_state.py | 4 ++-- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index af8ea822f3b..287b873efc4 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -190,17 +190,20 @@ def forward( """ emb = self.get_emb(max_seq_len, offset) packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' - if self.cp_group is not None and self.cp_group.size() > 1 and not packed_seq: - if packed_seq_params.local_cp_size is None: - cp_group = self.cp_group - elif packed_seq_params.local_cp_size > 1: + if packed_seq_params is not None and packed_seq_params.local_cp_size is not None: + if packed_seq_params.local_cp_size > 1: + # Set CP group to dynamic CP group for CP slicing cp_group = packed_seq_params.cp_group else: + # Set CP group to None to avoid CP slicing cp_group = None - if cp_group is not None: - # slice rotary_pos_emb along sequence dimension - # and select the parition of the current CP rank - emb = get_pos_emb_on_this_cp_rank(emb, 0, cp_group) + else: + cp_group = self.cp_group + + if cp_group is not None and cp_group.size() > 1 and not packed_seq: + # slice rotary_pos_emb along sequence dimension + # and select the parition of the current CP rank + emb = get_pos_emb_on_this_cp_rank(emb, 0, cp_group) return emb diff --git a/tests/unit_tests/test_parallel_state.py b/tests/unit_tests/test_parallel_state.py index a51c51db88d..eb5e424bb0e 100644 --- a/tests/unit_tests/test_parallel_state.py +++ b/tests/unit_tests/test_parallel_state.py @@ -518,10 +518,10 @@ def test_hybrid_dp_cp_groups(world_size, tp_size, cp_size, dp_size): hybrid_context_parallel=True, ) - dp_cp_size = parallel_state.get_data_parallel_world_size(with_context_parallel=True) + dp_cp_size = ps.get_data_parallel_world_size(with_context_parallel=True) group_sizes = [2**i for i in range(int(log2(dp_cp_size)))][1:] for group_size in group_sizes: - group = parallel_state.get_hybrid_data_context_parallel_groups(group_size=group_size) + group = ps.get_hybrid_data_context_parallel_groups(group_size=group_size) assert group.size() == group_size Utils.destroy_model_parallel() From 3c1f66a914da94d52d0eb67e99c9294d833b8b57 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 18 Sep 2025 08:11:35 -0700 Subject: [PATCH 34/81] Fix test import --- tests/unit_tests/test_parallel_state.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit_tests/test_parallel_state.py b/tests/unit_tests/test_parallel_state.py index eb5e424bb0e..56149796dfa 100644 --- a/tests/unit_tests/test_parallel_state.py +++ b/tests/unit_tests/test_parallel_state.py @@ -1,3 +1,4 @@ +from math import log2 import pytest import torch From aa6c7f7b08264ef9c4390f567eedd1f6d04ef7ef Mon Sep 17 00:00:00 2001 From: Mcore Bot Date: Thu, 18 Sep 2025 15:15:25 +0000 Subject: [PATCH 35/81] chore: Format files --- tests/unit_tests/test_parallel_state.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit_tests/test_parallel_state.py b/tests/unit_tests/test_parallel_state.py index 56149796dfa..8f97c6ee412 100644 --- a/tests/unit_tests/test_parallel_state.py +++ b/tests/unit_tests/test_parallel_state.py @@ -1,4 +1,5 @@ from math import log2 + import pytest import torch From 9d761cc4a5778d55b66e3f2ab8f406cf6b460f31 Mon Sep 17 00:00:00 2001 From: tailaim Date: Tue, 23 Sep 2025 02:23:07 -0700 Subject: [PATCH 36/81] fix some bugs and support TP Signed-off-by: tailaim --- megatron/core/parallel_state.py | 8 +- .../pipeline_parallel/hybrid_cp_schedule.py | 141 +++++++++++++----- megatron/training/utils.py | 31 ++-- 3 files changed, 128 insertions(+), 52 deletions(-) diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 0d86aeda1a0..be57a00f837 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -865,11 +865,11 @@ def initialize_model_parallel( del os.environ["NCCL_COLLNET_ENABLE"] if hybrid_context_parallel: - assert ( - len(ranks_with_cp) % 2 == 0 - ), "Hybrid context parallel requires an even number of ranks" global _HYBRID_DP_CP_GROUPS - if rank in ranks_with_cp: + for ranks_with_cp in decoder_rank_generator.get_ranks('dp-cp'): + assert ( + len(ranks_with_cp) % 2 == 0 + ), "Hybrid context parallel requires an even number of ranks" _HYBRID_DP_CP_GROUPS.update( create_hybrid_dp_cp_groups( rank, ranks_with_cp, get_nccl_options("dp_cp", nccl_comm_cfgs) diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index c599a21d382..f6b9355a239 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -152,9 +152,14 @@ def get_global_id_seqlens(self, num_local_subsamples, offsets, seqlens_gathered) def _gid_to_src_rank(self, gid: int, offsets: List[int]) -> int: dp_src_rank = torch.bucketize(gid, offsets[1:] - 1) - hdp_rank = torch.distributed.get_process_group_ranks( - parallel_state.get_data_parallel_group() - )[dp_src_rank] + # Since the torch.distributed.get_process_group_ranks + # provides the global rank, we need to consider TP + hdp_rank = ( + torch.distributed.get_process_group_ranks(parallel_state.get_data_parallel_group())[ + dp_src_rank + ] + // parallel_state.get_tensor_model_parallel_world_size() + ) return hdp_rank def reroute_samples_to_hdp_ranks( @@ -173,6 +178,10 @@ def reroute_samples_to_hdp_ranks( dp_ranks = torch.distributed.get_process_group_ranks( parallel_state.get_data_parallel_group() ) + # Here we actually want to get the DP group's rank within the HDP group, + # we need to consider TP + dp_ranks = [r // parallel_state.get_tensor_model_parallel_world_size() for r in dp_ranks] + data_keys = batch[0].keys() # Create the send plan @@ -290,7 +299,11 @@ def __next__(self) -> Any: """ Get the next item from the dataset, pull scheduling metadata and return it. """ - batch = next(self.data_iterator) + if self.data_iterator is None: + # TP0 reads from data_iterator, others receive via broadcast. + return None, None + else: + batch = next(self.data_iterator) subsample_seqlens = [] for sample in batch: subsample_seqlens.extend( @@ -668,8 +681,10 @@ def fill_empty_gpus( # Find the smallest group size that exists existing_group_sizes = set(group_size.values()) - if not existing_group_sizes: - return # No groups exist, cannot redistribute + assert ( + existing_group_sizes + ), "There should be at least one group existing, cannot reditribute, " + "try to increase 'max-seqlen-per-cp-rank'." min_group_size = min(existing_group_sizes) # We have Hybrid DPxCP groups for every power of 2 of GPUs or the entire DPxCP group. @@ -811,29 +826,75 @@ def hybrid_context_parallel_forward_backward( """ from .schedules import backward_step, forward_step + def _broadcast(item): + if item is not None: + torch.distributed.broadcast( + item, + parallel_state.get_tensor_model_parallel_src_rank(), + group=parallel_state.get_tensor_model_parallel_group(), + ) + + def _broadcast_num_samples_this_group(num_samples_this_group): + dev = torch.cuda.current_device() + torch.distributed.barrier() + + n = 0 if num_samples_this_group is None else int(num_samples_this_group.numel()) + n = torch.tensor([n], dtype=torch.int64, device=dev) + + _broadcast(n) + n = int(n.item()) + + assert n > 0, "there should be at least 1 sub samples in the group" + num_samples_this_group_broadcast = ( + torch.empty(n, dtype=torch.int32, device=dev) + if num_samples_this_group is None + else num_samples_this_group + ) + _broadcast(num_samples_this_group_broadcast) + return num_samples_this_group_broadcast + cp_balancing_scheduler = BalancedCPScheduler(max_seq_len_per_rank=config.max_seqlen_per_cp_rank) # We get data once per global batch and schedule the sub-samples. # TODO(pmannan): Should we wrap the data_iterator here instead of the training.py file? - data = next(data_iterator) - sample_id_groups = data[1] - batch = data[0] + dpcp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) + is_first_tp_rank = parallel_state.get_tensor_model_parallel_rank() == 0 + + if is_first_tp_rank: + data = next(data_iterator) + sample_id_groups = data[1] + batch = data[0] + else: + data, sample_id_groups, batch = None, None, None + + num_samples_this_group = None + if is_first_tp_rank: + num_samples_this_group = torch.tensor( + [len(group[dpcp_rank]) for group in sample_id_groups], dtype=torch.int32, device='cuda' + ) + + num_samples_this_group = _broadcast_num_samples_this_group(num_samples_this_group) + num_samples_this_group = num_samples_this_group.cpu().numpy() + num_total_groups = num_samples_this_group.shape[0] + # TODO: How does this variable affect downstream logic? num_microbatches = 1 # Upto last group, we don't need any sync. with no_sync_func(): - for j in range(len(sample_id_groups) - 1): - sample_ids_this_group = sample_id_groups[j][ - parallel_state.get_data_parallel_rank(with_context_parallel=True) - ] - for sub_sample_id in sample_ids_this_group: + for j in range(num_total_groups - 1): + sample_ids_this_group = sample_id_groups[j][dpcp_rank] if is_first_tp_rank else None + for i in range(num_samples_this_group[j]): # Call forward step for each sub-sample - sample = batch[sub_sample_id] - partner_cp_size = len( - [True for sample_ids in sample_id_groups[j] if sub_sample_id in sample_ids] - ) - sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) - new_data_iterator = RerunDataIterator(iter([sample])) + if is_first_tp_rank: + sub_sample_id = sample_ids_this_group[i] + sample = batch[sub_sample_id] + partner_cp_size = len( + [True for sample_ids in sample_id_groups[j] if sub_sample_id in sample_ids] + ) + sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) + new_data_iterator = RerunDataIterator(iter([sample])) + else: + new_data_iterator = None # TODO: Find the usage of current_microbatch and is_first_microbatch and # how that may affect my usage. output_tensor, num_tokens = forward_step( @@ -865,17 +926,18 @@ def hybrid_context_parallel_forward_backward( # For the last group, we need to run the last sub-sample out of the context handler. with no_sync_func(): - sample_ids_this_group = sample_id_groups[-1][ - parallel_state.get_data_parallel_rank(with_context_parallel=True) - ] - for k in range(len(sample_ids_this_group) - 1): - sub_sample_id = sample_ids_this_group[k] - sample = batch[sub_sample_id] - partner_cp_size = len( - [True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids] - ) - sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) - new_data_iterator = RerunDataIterator(iter([sample])) + sample_ids_this_group = sample_id_groups[-1][dpcp_rank] if is_first_tp_rank else None + for k in range(num_samples_this_group[-1] - 1): + if is_first_tp_rank: + sub_sample_id = sample_ids_this_group[k] + sample = batch[sub_sample_id] + partner_cp_size = len( + [True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids] + ) + sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) + new_data_iterator = RerunDataIterator(iter([sample])) + else: + new_data_iterator = None # Call forward step for each sub-sample output_tensor, num_tokens = forward_step( forward_step_func, @@ -897,13 +959,16 @@ def hybrid_context_parallel_forward_backward( # The last sub-sample of the last group of the last microbatch is # run out of the context handler. - sub_sample_id = sample_ids_this_group[-1] - partner_cp_size = len( - [True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids] - ) - sample = batch[sub_sample_id] - sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) - new_data_iterator = RerunDataIterator(iter([sample])) + if is_first_tp_rank: + sub_sample_id = sample_ids_this_group[-1] + sample = batch[sub_sample_id] + partner_cp_size = len( + [True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids] + ) + sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) + new_data_iterator = RerunDataIterator(iter([sample])) + else: + new_data_iterator = None # Call forward step for each sub-sample output_tensor, num_tokens = forward_step( forward_step_func, diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 33a9d0c1c51..87083f93ed4 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -503,6 +503,7 @@ def get_blend_and_blend_per_split(args): def get_batch_on_this_tp_rank(data_iterator): args = get_args() + assert not (args.pipeline_model_parallel_size > 1 and args.hybrid_context_parallel), "Context parallelism not supported with pipeline parallelism now" def _broadcast(item): if item is not None: @@ -545,7 +546,7 @@ def _broadcast(item): def _broadcast_cu_seqlens(cu_seqlens): dev = torch.cuda.current_device() - + # when enabling hybrid context parallel, cu_seqlens is always None n = 0 if cu_seqlens is None else int(cu_seqlens.numel()) n_tensor = torch.tensor(n, dtype=torch.int64, device=dev) _broadcast(n_tensor) @@ -559,7 +560,10 @@ def _broadcast_cu_seqlens(cu_seqlens): buf = cu_seqlens.to(device=dev, non_blocking=True).contiguous() _broadcast(buf) - + if args.hybrid_context_parallel: + seq_len = torch.tensor(batch['tokens'].shape[0], dtype=torch.int32, device=torch.cuda.current_device()) + _broadcast(seq_len) + if args.pipeline_model_parallel_size == 1: _broadcast(batch['tokens']) _broadcast(batch['labels']) @@ -591,32 +595,39 @@ def _broadcast_cu_seqlens(cu_seqlens): _broadcast(batch['attention_mask']) else: - + if args.hybrid_context_parallel: + seq_len = torch.tensor(0, dtype=torch.int32, device=torch.cuda.current_device()) + _broadcast(seq_len) + shape = (seq_len.item()) + else: + shape = (args.micro_batch_size, args.seq_length) + tokens = torch.empty( - (args.micro_batch_size, args.seq_length), + shape, dtype=torch.int64, device=torch.cuda.current_device(), ) labels = torch.empty( - (args.micro_batch_size, args.seq_length), + shape, dtype=torch.int64, device=torch.cuda.current_device(), ) loss_mask = torch.empty( - (args.micro_batch_size, args.seq_length), + shape, dtype=torch.float32, device=torch.cuda.current_device(), ) if args.create_attention_mask_in_dataloader: + shape_attention_mask = (args.micro_batch_size, 1, args.seq_length, args.seq_length) if not args.hybrid_context_parallel else (1, 1, shape[0], shape[0]) attention_mask = torch.empty( - (args.micro_batch_size, 1, args.seq_length, args.seq_length), + shape_attention_mask, dtype=torch.bool, device=torch.cuda.current_device(), ) else: attention_mask = None position_ids = torch.empty( - (args.micro_batch_size, args.seq_length), + shape, dtype=torch.int64, device=torch.cuda.current_device(), ) @@ -626,12 +637,12 @@ def _broadcast_cu_seqlens(cu_seqlens): 1, dtype=torch.int32, device=torch.cuda.current_device(), - ) + ) if not args.hybrid_context_parallel else None local_cp_size = torch.empty( 1, dtype=torch.int32, device=torch.cuda.current_device(), - ) + ) if args.hybrid_context_parallel else None def _broadcast_cu_seqlens(): dev = torch.cuda.current_device() From 7d6c2b1dec93e54a46a5cad545b4227e37107176 Mon Sep 17 00:00:00 2001 From: tailaim Date: Tue, 23 Sep 2025 07:35:35 -0700 Subject: [PATCH 37/81] fix pad issue, should pad for sequence parallel when TP>1 and consider hdp Signed-off-by: tailaim --- megatron/core/datasets/gpt_dataset.py | 6 ++++++ megatron/training/datasets/sft_dataset.py | 8 +++++--- pretrain_gpt.py | 2 ++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index 49ebb45d935..044f2c83d2d 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -58,6 +58,12 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig): data_parallel_size: int = 1 """Option to enable data parallelism""" + tensor_model_parallel_size: int = 1 + """Option to enable tensor model parallelism""" + + hybrid_context_parallel: bool = False + """Option to enable hybrid context parallelism""" + def __post_init__(self) -> None: """Do asserts and set fields post init""" super().__post_init__() diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index 64fc15a0b85..a0d81ab2fcb 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -123,9 +123,11 @@ def extend_with_padding(tokens, targets, positions, pad_len): assert not self.config.reset_position_ids pack_positions.extend(range(len(tokens_list))) - if self.config.context_parallel_size > 1: - # TODO(pmannan): This is a hack to pad for Hybrid DPxCP. - pad_granularity = self.config.context_parallel_size * self.config.data_parallel_size * 2 + if self.config.context_parallel_size > 1 or self.config.tensor_model_parallel_size > 1: + cp_pad = self.config.context_parallel_size * 2 if self.config.context_parallel_size > 1 else 1 + cp_pad = cp_pad * self.config.data_parallel_size if self.config.hybrid_context_parallel else cp_pad + tp_pad = self.config.tensor_model_parallel_size + pad_granularity = cp_pad * tp_pad mod_token_count = len(pack_tokens) % pad_granularity if mod_token_count != 0: pad_len = pad_granularity - mod_token_count diff --git a/pretrain_gpt.py b/pretrain_gpt.py index c0533452d07..12d44f2ba3c 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -285,6 +285,8 @@ def core_gpt_dataset_config_from_args(args): mid_level_dataset_surplus=args.mid_level_dataset_surplus, context_parallel_size=args.context_parallel_size, data_parallel_size=args.data_parallel_size, + tensor_model_parallel_size=args.tensor_model_parallel_size, + hybrid_context_parallel=args.hybrid_context_parallel, ) From fbdfe180a567905f3bbbf13ec1316c82b1b90878 Mon Sep 17 00:00:00 2001 From: tailaim Date: Wed, 24 Sep 2025 03:06:33 -0700 Subject: [PATCH 38/81] little adjustment according to comments Signed-off-by: tailaim --- megatron/core/datasets/gpt_dataset.py | 4 +- .../pipeline_parallel/hybrid_cp_schedule.py | 54 +++++++------------ megatron/training/arguments.py | 2 + megatron/training/datasets/sft_dataset.py | 4 +- megatron/training/utils.py | 1 - pretrain_gpt.py | 2 +- 6 files changed, 27 insertions(+), 40 deletions(-) diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index 044f2c83d2d..c3acd6c9251 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -58,8 +58,8 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig): data_parallel_size: int = 1 """Option to enable data parallelism""" - tensor_model_parallel_size: int = 1 - """Option to enable tensor model parallelism""" + sequence_parallel_size: int = 1 + """Option to enable sequence parallelism""" hybrid_context_parallel: bool = False """Option to enable hybrid context parallelism""" diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index f6b9355a239..2e077587498 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -853,10 +853,23 @@ def _broadcast_num_samples_this_group(num_samples_this_group): _broadcast(num_samples_this_group_broadcast) return num_samples_this_group_broadcast + def _get_new_data_iterator(sample_id_in_group, group_id): + if is_first_tp_rank: + sub_sample_id = sample_ids_this_group[sample_id_in_group] + sample = batch[sub_sample_id] + partner_cp_size = len( + [True for sample_ids in sample_id_groups[group_id] if sub_sample_id in sample_ids] + ) + sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) + new_data_iterator = RerunDataIterator(iter([sample])) + return new_data_iterator + else: + return None + cp_balancing_scheduler = BalancedCPScheduler(max_seq_len_per_rank=config.max_seqlen_per_cp_rank) # We get data once per global batch and schedule the sub-samples. # TODO(pmannan): Should we wrap the data_iterator here instead of the training.py file? - dpcp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) + hdp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) is_first_tp_rank = parallel_state.get_tensor_model_parallel_rank() == 0 if is_first_tp_rank: @@ -869,7 +882,7 @@ def _broadcast_num_samples_this_group(num_samples_this_group): num_samples_this_group = None if is_first_tp_rank: num_samples_this_group = torch.tensor( - [len(group[dpcp_rank]) for group in sample_id_groups], dtype=torch.int32, device='cuda' + [len(group[hdp_rank]) for group in sample_id_groups], dtype=torch.int32, device='cuda' ) num_samples_this_group = _broadcast_num_samples_this_group(num_samples_this_group) @@ -882,19 +895,10 @@ def _broadcast_num_samples_this_group(num_samples_this_group): # Upto last group, we don't need any sync. with no_sync_func(): for j in range(num_total_groups - 1): - sample_ids_this_group = sample_id_groups[j][dpcp_rank] if is_first_tp_rank else None + sample_ids_this_group = sample_id_groups[j][hdp_rank] if is_first_tp_rank else None for i in range(num_samples_this_group[j]): # Call forward step for each sub-sample - if is_first_tp_rank: - sub_sample_id = sample_ids_this_group[i] - sample = batch[sub_sample_id] - partner_cp_size = len( - [True for sample_ids in sample_id_groups[j] if sub_sample_id in sample_ids] - ) - sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) - new_data_iterator = RerunDataIterator(iter([sample])) - else: - new_data_iterator = None + new_data_iterator = _get_new_data_iterator(i, j) # TODO: Find the usage of current_microbatch and is_first_microbatch and # how that may affect my usage. output_tensor, num_tokens = forward_step( @@ -926,18 +930,9 @@ def _broadcast_num_samples_this_group(num_samples_this_group): # For the last group, we need to run the last sub-sample out of the context handler. with no_sync_func(): - sample_ids_this_group = sample_id_groups[-1][dpcp_rank] if is_first_tp_rank else None + sample_ids_this_group = sample_id_groups[-1][hdp_rank] if is_first_tp_rank else None for k in range(num_samples_this_group[-1] - 1): - if is_first_tp_rank: - sub_sample_id = sample_ids_this_group[k] - sample = batch[sub_sample_id] - partner_cp_size = len( - [True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids] - ) - sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) - new_data_iterator = RerunDataIterator(iter([sample])) - else: - new_data_iterator = None + new_data_iterator = _get_new_data_iterator(k, -1) # Call forward step for each sub-sample output_tensor, num_tokens = forward_step( forward_step_func, @@ -959,16 +954,7 @@ def _broadcast_num_samples_this_group(num_samples_this_group): # The last sub-sample of the last group of the last microbatch is # run out of the context handler. - if is_first_tp_rank: - sub_sample_id = sample_ids_this_group[-1] - sample = batch[sub_sample_id] - partner_cp_size = len( - [True for sample_ids in sample_id_groups[-1] if sub_sample_id in sample_ids] - ) - sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) - new_data_iterator = RerunDataIterator(iter([sample])) - else: - new_data_iterator = None + new_data_iterator = _get_new_data_iterator(-1, -1) # Call forward step for each sub-sample output_tensor, num_tokens = forward_step( forward_step_func, diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 94b08e76217..38623c1c692 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -888,6 +888,8 @@ def validate_args(args, defaults={}): if args.tp_comm_overlap: assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled' + assert not (args.pipeline_model_parallel_size > 1 and args.hybrid_context_parallel), "Context parallelism not supported with pipeline parallelism now" + # disable async_tensor_model_parallel_allreduce when # model parallel memory optimization is enabled if (args.tensor_model_parallel_size > 1 or args.context_parallel_size > 1) \ diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index a0d81ab2fcb..6e89201984f 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -123,10 +123,10 @@ def extend_with_padding(tokens, targets, positions, pad_len): assert not self.config.reset_position_ids pack_positions.extend(range(len(tokens_list))) - if self.config.context_parallel_size > 1 or self.config.tensor_model_parallel_size > 1: + if self.config.context_parallel_size > 1 or self.config.sequence_parallel_size > 1: cp_pad = self.config.context_parallel_size * 2 if self.config.context_parallel_size > 1 else 1 cp_pad = cp_pad * self.config.data_parallel_size if self.config.hybrid_context_parallel else cp_pad - tp_pad = self.config.tensor_model_parallel_size + tp_pad = self.config.sequence_parallel_size if self.config.sequence_parallel_size > 0 else 1 pad_granularity = cp_pad * tp_pad mod_token_count = len(pack_tokens) % pad_granularity if mod_token_count != 0: diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 87083f93ed4..d40bf0be27e 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -503,7 +503,6 @@ def get_blend_and_blend_per_split(args): def get_batch_on_this_tp_rank(data_iterator): args = get_args() - assert not (args.pipeline_model_parallel_size > 1 and args.hybrid_context_parallel), "Context parallelism not supported with pipeline parallelism now" def _broadcast(item): if item is not None: diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 12d44f2ba3c..dec7735efa3 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -285,7 +285,7 @@ def core_gpt_dataset_config_from_args(args): mid_level_dataset_surplus=args.mid_level_dataset_surplus, context_parallel_size=args.context_parallel_size, data_parallel_size=args.data_parallel_size, - tensor_model_parallel_size=args.tensor_model_parallel_size, + sequence_parallel_size=args.tensor_model_parallel_size*args.sequence_parallel, hybrid_context_parallel=args.hybrid_context_parallel, ) From 8e6aa49a01166972b2065d14b3a1b792657f3a21 Mon Sep 17 00:00:00 2001 From: tailaim Date: Wed, 24 Sep 2025 23:05:47 -0700 Subject: [PATCH 39/81] little fix to loop parameter Signed-off-by: tailaim --- megatron/core/pipeline_parallel/hybrid_cp_schedule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index 2e077587498..cc1729ac504 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -931,8 +931,8 @@ def _get_new_data_iterator(sample_id_in_group, group_id): # For the last group, we need to run the last sub-sample out of the context handler. with no_sync_func(): sample_ids_this_group = sample_id_groups[-1][hdp_rank] if is_first_tp_rank else None - for k in range(num_samples_this_group[-1] - 1): - new_data_iterator = _get_new_data_iterator(k, -1) + for i in range(num_samples_this_group[-1] - 1): + new_data_iterator = _get_new_data_iterator(i, -1) # Call forward step for each sub-sample output_tensor, num_tokens = forward_step( forward_step_func, From f050f54ff602065ed7b9983aa2d97efa0a8346db Mon Sep 17 00:00:00 2001 From: tailaim Date: Thu, 25 Sep 2025 23:06:41 -0700 Subject: [PATCH 40/81] disable cudagraph anb fsdp when using hybrid cp Signed-off-by: tailaim --- megatron/training/arguments.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 38623c1c692..25120dcab6f 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -888,7 +888,11 @@ def validate_args(args, defaults={}): if args.tp_comm_overlap: assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled' - assert not (args.pipeline_model_parallel_size > 1 and args.hybrid_context_parallel), "Context parallelism not supported with pipeline parallelism now" + assert not (args.hybrid_context_parallel + and (args.pipeline_model_parallel_size > 1 + or args.enable_cuda_graph + or args.use_megatron_fsdp) + ), "Hybrid context parallelism not supported with pipeline parallelism, CUDA Graph and FSDP now" # disable async_tensor_model_parallel_allreduce when # model parallel memory optimization is enabled From 1e76e8e867a852105f50f955405f1e275ffe2767 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Sun, 28 Sep 2025 21:15:51 -0700 Subject: [PATCH 41/81] Add cp comm type to dynamic CP attn --- megatron/core/extensions/transformer_engine.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 55a34051a23..1e6916ed1d3 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -1012,16 +1012,19 @@ def forward( ): """Forward.""" if packed_seq_params is not None: + # If Dynamic CP group is provided, update TE DPA CP group if packed_seq_params.cp_group is not None: self.cp_group = packed_seq_params.cp_group - # TODO(pmannan): Add cp_comm_type super().set_context_parallel_group( self.cp_group, torch.distributed.get_process_group_ranks(self.cp_group), TEDotProductAttention.cp_stream, + self.cp_comm_type, ) + # If cp_group is None but local_cp_size is provided, + # Indicates to turn off CP dynamically elif packed_seq_params.local_cp_size is not None: - super().set_context_parallel_group(None, None, None) + super().set_context_parallel_group(None, None, None, self.cp_comm_type) self.kept_packed_seq_params.discard("cp_group") self.kept_packed_seq_params.discard("local_cp_size") packed_seq_kwargs = ( From 3a39a3ab996b35b033434137433b8abd48fde0d8 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Sun, 28 Sep 2025 21:36:56 -0700 Subject: [PATCH 42/81] Cleanup --- megatron/core/datasets/gpt_dataset.py | 4 ++-- pretrain_gpt.py | 12 +++--------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index c3acd6c9251..959f5ef6d28 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -58,8 +58,8 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig): data_parallel_size: int = 1 """Option to enable data parallelism""" - sequence_parallel_size: int = 1 - """Option to enable sequence parallelism""" + sequence_parallel_size: int = 0 + """Option to indicate the sequence parallelism size when using TP""" hybrid_context_parallel: bool = False """Option to enable hybrid context parallelism""" diff --git a/pretrain_gpt.py b/pretrain_gpt.py index dec7735efa3..fafb7406d58 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -85,16 +85,8 @@ def get_batch(data_iterator, vp_stage=None): # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) # The implementation of this function is in MCore elif local_cp_size is None: # Packed THD format - # assert ( - # cu_seqlens.dim() == 2 and cu_seqlens.shape[0] == 1 - # ), "micro-batch-size must be 1 for packing" cu_seqlens = cu_seqlens[0] - # batch['cu_seqlens'] = cu_seqlens - - # max_seqlen = batch['max_seqlen'] assert max_seqlen.dim() == 1 - # TODO(duncan): can this be kept as a 0-D tensor? - # batch['max_seqlen'] = int(max_seqlen[0].item()) cp_size = get_context_parallel_world_size() if cp_size > 1: # slice batch along sequence dimension for context parallelism @@ -110,7 +102,7 @@ def get_batch(data_iterator, vp_stage=None): cp_rank, ) for key, data in batch.items(): - if key in {'attention_mask', 'cu_seqlens', 'max_seqlen'}: + if key in {'attention_mask'}: continue batch[key] = data.index_select(1, index) else: # Hybrid CP format @@ -140,6 +132,8 @@ def get_batch(data_iterator, vp_stage=None): ) if cp_group is not None and cp_group.size() > 1: + # Each sub-sample of a packed sample is required to be divisible by CP*DP*2 + # or CP*DP*TP*2 (if using sequence parallel) batch = get_batch_on_this_cp_rank(batch, cp_group.size(), torch.distributed.get_rank(group=cp_group)) return (*batch.values(), packed_seq_params) From cb72e4e85b7062055070836fb3ac5bb50ca29296 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 2 Oct 2025 16:31:53 -0700 Subject: [PATCH 43/81] Removing benchmarking scripts for MR --- examples/run_hybrid_cp.sh | 206 ---------------------------- examples/run_llama38b_hybrid_cp.sh | 213 ----------------------------- interactive_pytorch25.02.sh | 29 ---- 3 files changed, 448 deletions(-) delete mode 100755 examples/run_hybrid_cp.sh delete mode 100755 examples/run_llama38b_hybrid_cp.sh delete mode 100755 interactive_pytorch25.02.sh diff --git a/examples/run_hybrid_cp.sh b/examples/run_hybrid_cp.sh deleted file mode 100755 index d792878b599..00000000000 --- a/examples/run_hybrid_cp.sh +++ /dev/null @@ -1,206 +0,0 @@ -#!/bin/bash - -#SBATCH -A coreai_dlalgo_llm -# DFW: batch -# OCI-NRT: batch_block1 -# OCI-IAD: batch_block1,batch_block3,batch_block4,backfill_block1,backfill_block2,backfill_block3,backfill_block4 -#SBATCH -p batch -#SBATCH -t 00:30:00 -#SBATCH --mem=0 -#SBATCH --ntasks-per-node=8 -#SBATCH --nodes=1 -#SBATCH --exclusive -#SBATCH --gpus-per-node=8 -#SBATCH --job-name=hetero_cp_global - -export NCCL_IB_SL=1 -export TOKENIZERS_PARALLELISM="false" -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True - -#export NVTE_DEBUG=1 -#export NVTE_DEBUG_LEVEL=2 - -USER=$SLURM_JOB_USER - -# Auto-detect batch or interactive mode. -which srun -BATCH=$((1-$?)) - -DEBUG=0 -USE_TILING=1 -USE_CP=1 -USE_TE_CE=0 -USE_FLASH_ATTN=0 -USE_FSDP=0 -USE_CUSTOM_FSDP=0 -PROFILE=0 - -# Remember to update model and job name if running in batch mode!! -if [[ $BATCH -eq 0 ]]; then - DATETIME=`date +'%y-%m-%d-%H-%M-%S'` - MODEL_NAME="interactive_hybrid_cp" - DEBUG=1 -else - MODEL_NAME="interactive_hybrid_cp" -fi - -WORKSPACE="/lustre/fsw/portfolios/coreai/users/pmannan/workspace" -SOURCE=`pwd` -OUTPUT_BASE="${WORKSPACE}/output" -OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" - -FINETUNE_DIR=${OUTPUT}/checkpoints -LOGS_DIR="${OUTPUT}/logs" -TENSORBOARD_DIR="${OUTPUT}/tensorboard" -DATACACHE_DIR="${OUTPUT}/data_cache" - -export HF_DATASETS_CACHE="${OUTPUT}/hf_datasets_cache" - -DATA_TRAIN="/lustre/fs1/portfolios/llmservice/users/adithyare/sft/nano_v2_fake_packed_131072_10000_rndm//stage1_stage2_multiling_128k_seq_packed.empty_assist_filtered.shuf.jsonl" - -SEQ_LEN=131072 #131072 #81920 #65536 - -if [[ $DEBUG -eq 1 ]]; then - MBZ=1 - BZ=8 - NW=4 - AD=0.0 - HD=0.0 - LI=1 - - # EXTRA_ARGS="--deterministic-mode --use-cpu-initialization" - - NONDETERMINISTIC_ATTN=1 - - NUM_GPU=8 - export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - - #export NCCL_ALGO=Tree - #export CUBLAS_WORKSPACE_CONFIG=:4096:8 -else - MBZ=1 - BZ=64 - NW=8 - AD=0.0 - HD=0.0 - LI=1 - EXTRA_ARGS="" - NONDETERMINISTIC_ATTN=1 - NUM_GPU=8 -fi - -if [[ $USE_CP -eq 1 ]]; then - if [[ $BATCH -eq 1 ]]; then - CP_SIZE=8 - else - CP_SIZE=8 - fi - EXTRA_ARGS+=" --context-parallel-size ${CP_SIZE} " -fi - -if [[ $USE_TE_CE -eq 1 ]]; then - EXTRA_ARGS+=" --enable-te-ce --cross-entropy-loss-fusion " -fi - -if [[ $PROFILE -eq 1 ]]; then - EXTRA_ARGS+="--profile --profile-step-start 7 --profile-step-end 8 " -fi - -# CHECKPOINT_DIR="/lustre/fsw/portfolios/llmservice/users/trintamaki/workspace/output/video_sft_stage2_qwen_2p5_7b_radio_research_cp_0429_tp2/checkpoints" -TP=1 -EXTRA_ARGS+=" --ckpt-format torch_dist --use-distributed-optimizer " -# EXTRA_ARGS+=" --overlap-param-gather --overlap-grad-reduce " -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -EXTRA_ARGS+=" --use-precision-aware-optimizer --main-grads-dtype bf16 --main-params-dtype fp16 --exp-avg-dtype fp16 --exp-avg-sq-dtype fp16 " - -OPTIONS=" \ - --disable-bias-linear \ - --sft \ - --tokenizer-type SFTTokenizer \ - --sft-tokenizer-prompt-format nemotron-h-aligned \ - --tokenizer-model /lustre/fsw/portfolios/llmservice/users/kezhik/images/Nemotron-H-4B-Instruct \ - --transformer-impl transformer_engine \ - --normalization RMSNorm \ - --norm-epsilon 1e-06 \ - --attention-dropout ${AD} \ - --hidden-dropout ${HD} \ - --untie-embeddings-and-output-weights \ - --position-embedding-type rope \ - --rotary-percent 1.0 \ - --rotary-base 1000000 \ - --swiglu \ - --tensor-model-parallel-size ${TP} \ - --pipeline-model-parallel-size 1 \ - --rerun-mode disabled \ - --num-layers 4 \ - --hidden-size 2048 \ - --ffn-hidden-size 8192 \ - --add-qkv-bias \ - --num-attention-heads 16 \ - --num-workers ${NW} \ - --exit-duration-in-mins 230 \ - --seq-length ${SEQ_LEN} \ - --max-position-embeddings ${SEQ_LEN} \ - --train-samples 100000 \ - --lr-warmup-samples 20000 \ - --micro-batch-size ${MBZ} \ - --global-batch-size ${BZ} \ - --lr 2e-5 \ - --min-lr 0.0 \ - --lr-decay-style cosine \ - --log-interval ${LI} \ - --eval-iters 10 \ - --eval-interval 999999 \ - --data-path ${DATA_TRAIN} \ - --save-interval 1000 \ - --save ${FINETUNE_DIR} \ - --load ${FINETUNE_DIR} \ - --data-cache-path ${DATACACHE_DIR} \ - --use-mcore-models \ - --no-create-attention-mask-in-dataloader \ - --no-mmap-bin-files \ - --split 100,0,0 \ - --clip-grad 1.0 \ - --weight-decay 0.05 \ - --adam-beta1 0.9 \ - --adam-beta2 0.999 \ - --init-method-std 0.014 \ - --bf16 \ - --tensorboard-dir ${TENSORBOARD_DIR} \ - ${EXTRA_ARGS} \ - --distributed-timeout-minutes 60 \ - --calculate-per-token-loss \ - --attention-backend flash \ - --disable-gloo-process-groups \ - --use-dist-ckpt \ - --hybrid-context-parallel \ - --max-seqlen-per-cp-rank 16384 \ -" -# --hybrid-context-parallel \ -# --max-seqlen-per-cp-rank 16384 \ - -# Interactive or batch mode -if [[ $BATCH -eq 0 ]]; then - if [[ $PROFILE -eq 1 ]]; then - nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi -o gpt_sft_hetero_cp_iter7_8_flash_global_64 torchrun --nproc_per_node ${NUM_GPU} pretrain_gpt.py ${OPTIONS} - else - torchrun --nproc_per_node ${NUM_GPU} pretrain_gpt.py ${OPTIONS} - fi -else - if [[ $PROFILE -eq 1 ]]; then - run_cmd="cd ${SOURCE}; nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi --capture-range-end stop -o without_hetero_cp_global_%q{SLURM_PROCID} python -u pretrain_gpt.py ${OPTIONS}" - else - run_cmd="cd ${SOURCE}; python -u pretrain_gpt.py ${OPTIONS}" - fi - - DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` - - srun -l --verbose \ - --container-image /lustre/fsw/portfolios/coreai/users/pmannan/workspace/megatron_vlm_25.02_stage3-te_api.sqsh \ - --container-mounts "/lustre" \ - --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ - sh -c "${run_cmd}" - - set +x -fi \ No newline at end of file diff --git a/examples/run_llama38b_hybrid_cp.sh b/examples/run_llama38b_hybrid_cp.sh deleted file mode 100755 index a8d618da5aa..00000000000 --- a/examples/run_llama38b_hybrid_cp.sh +++ /dev/null @@ -1,213 +0,0 @@ -#!/bin/bash - -#SBATCH -A coreai_dlalgo_genai -# DFW: batch -# OCI-NRT: batch_block1 -# OCI-IAD: batch_block1,batch_block3,batch_block4,backfill_block1,backfill_block2,backfill_block3,backfill_block4 -#SBATCH -p batch -#SBATCH -t 00:20:00 -#SBATCH --mem=0 -#SBATCH --ntasks-per-node=8 -#SBATCH --nodes=4 -#SBATCH --exclusive -#SBATCH --gpus-per-node=8 -#SBATCH --job-name=hetero_cp_global - -export NCCL_IB_SL=1 -export TOKENIZERS_PARALLELISM="false" -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True - -#export NVTE_DEBUG=1 -#export NVTE_DEBUG_LEVEL=2 - -USER=$SLURM_JOB_USER - -# Auto-detect batch or interactive mode. -which srun -BATCH=$((1-$?)) - -DEBUG=0 -USE_TILING=1 -USE_CP=1 -USE_TE_CE=0 -USE_FLASH_ATTN=0 -USE_FSDP=0 -USE_CUSTOM_FSDP=0 -PROFILE=0 - -# Remember to update model and job name if running in batch mode!! -if [[ $BATCH -eq 0 ]]; then - DATETIME=`date +'%y-%m-%d-%H-%M-%S'` - MODEL_NAME="interactive_hybrid_cp" - DEBUG=1 -else - MODEL_NAME="interactive_hybrid_cp" -fi - -WORKSPACE="/lustre/fsw/portfolios/coreai/users/pmannan/workspace" -SOURCE=`pwd` -OUTPUT_BASE="${WORKSPACE}/output" -OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" - -FINETUNE_DIR=${OUTPUT}/checkpoints -LOGS_DIR="${OUTPUT}/logs" -TENSORBOARD_DIR="${OUTPUT}/tensorboard" -DATACACHE_DIR="${OUTPUT}/data_cache" - -export HF_DATASETS_CACHE="${OUTPUT}/hf_datasets_cache" - -DATA_TRAIN="/lustre/fs1/portfolios/llmservice/users/adithyare/sft/nano_v2_fake_packed_131072_10000_rndm//stage1_stage2_multiling_128k_seq_packed.empty_assist_filtered.shuf.jsonl" - -SEQ_LEN=131072 #131072 #81920 #65536 - -if [[ $DEBUG -eq 1 ]]; then - MBZ=1 - BZ=64 - NW=4 - AD=0.0 - HD=0.0 - LI=1 - - # EXTRA_ARGS="--deterministic-mode --use-cpu-initialization" - - NONDETERMINISTIC_ATTN=1 - - NUM_GPU=8 - export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - - #export NCCL_ALGO=Tree - #export CUBLAS_WORKSPACE_CONFIG=:4096:8 -else - MBZ=1 - BZ=64 - NW=8 - AD=0.0 - HD=0.0 - LI=1 - EXTRA_ARGS="" - NONDETERMINISTIC_ATTN=1 - NUM_GPU=8 -fi - -if [[ $USE_CP -eq 1 ]]; then - if [[ $BATCH -eq 1 ]]; then - CP_SIZE=8 - else - CP_SIZE=8 - fi - EXTRA_ARGS+=" --context-parallel-size ${CP_SIZE} " -fi - -if [[ $USE_TE_CE -eq 1 ]]; then - EXTRA_ARGS+=" --enable-te-ce --cross-entropy-loss-fusion " -fi - -if [[ $PROFILE -eq 1 ]]; then - EXTRA_ARGS+="--profile --profile-step-start 7 --profile-step-end 8 --profile-ranks 0 16 " -fi - -# CHECKPOINT_DIR="/lustre/fsw/portfolios/llmservice/users/trintamaki/workspace/output/video_sft_stage2_qwen_2p5_7b_radio_research_cp_0429_tp2/checkpoints" -TP=1 -EXTRA_ARGS+=" --ckpt-format torch_dist --use-distributed-optimizer " -# EXTRA_ARGS+=" --overlap-param-gather --overlap-grad-reduce " -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -EXTRA_ARGS+=" --use-precision-aware-optimizer --main-grads-dtype bf16 --main-params-dtype fp16 --exp-avg-dtype fp16 --exp-avg-sq-dtype fp16 " - -OPTIONS=" \ - --disable-bias-linear \ - --sft \ - --legacy-tokenizer \ - --tokenizer-type SFTTokenizer \ - --sft-tokenizer-prompt-format nemotron-h-aligned \ - --tokenizer-model /lustre/fsw/portfolios/llmservice/users/kezhik/images/Nemotron-H-4B-Instruct \ - --transformer-impl transformer_engine \ - --normalization RMSNorm \ - --norm-epsilon 1e-06 \ - --attention-dropout ${AD} \ - --hidden-dropout ${HD} \ - --untie-embeddings-and-output-weights \ - --position-embedding-type rope \ - --rotary-percent 1.0 \ - --rotary-base 1000000 \ - --swiglu \ - --tensor-model-parallel-size ${TP} \ - --pipeline-model-parallel-size 1 \ - --rerun-mode disabled \ - --num-layers 28 \ - --hidden-size 4096 \ - --ffn-hidden-size 14336 \ - --add-qkv-bias \ - --num-attention-heads 32 \ - --num-workers ${NW} \ - --exit-duration-in-mins 230 \ - --seq-length ${SEQ_LEN} \ - --max-position-embeddings ${SEQ_LEN} \ - --train-samples 100000 \ - --lr-warmup-samples 20000 \ - --micro-batch-size ${MBZ} \ - --global-batch-size ${BZ} \ - --lr 2e-5 \ - --min-lr 0.0 \ - --lr-decay-style cosine \ - --log-interval ${LI} \ - --eval-iters 10 \ - --eval-interval 999999 \ - --data-path ${DATA_TRAIN} \ - --save-interval 1000 \ - --save ${FINETUNE_DIR} \ - --load ${FINETUNE_DIR} \ - --data-cache-path ${DATACACHE_DIR} \ - --use-mcore-models \ - --no-create-attention-mask-in-dataloader \ - --no-mmap-bin-files \ - --split 100,0,0 \ - --clip-grad 1.0 \ - --weight-decay 0.05 \ - --adam-beta1 0.9 \ - --adam-beta2 0.999 \ - --init-method-std 0.014 \ - --bf16 \ - --tensorboard-dir ${TENSORBOARD_DIR} \ - ${EXTRA_ARGS} \ - --distributed-timeout-minutes 60 \ - --calculate-per-token-loss \ - --attention-backend flash \ - --group-query-attention \ - --num-query-groups 8 \ - --disable-gloo-process-groups \ - --use-dist-ckpt \ - --hybrid-context-parallel \ - --max-seqlen-per-cp-rank 8192 \ -" - -# --recompute-method block \ -# --recompute-num-layers 14 \ -# --recompute-granularity full \ - -# Interactive or batch mode -if [[ $BATCH -eq 0 ]]; then - if [[ $PROFILE -eq 1 ]]; then - nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi -o gpt_sft_hetero_cp_iter7_8_flash_global_64 torchrun --nproc_per_node ${NUM_GPU} pretrain_gpt.py ${OPTIONS} - else - torchrun --nproc_per_node ${NUM_GPU} pretrain_gpt.py ${OPTIONS} - fi -else - if [[ $PROFILE -eq 1 ]]; then - run_cmd="cd ${SOURCE}; nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi --capture-range-end stop -o llama38b_hybrid_cp_%q{SLURM_PROCID} python -u pretrain_gpt.py ${OPTIONS}" - else - run_cmd="cd ${SOURCE}; python -u pretrain_gpt.py ${OPTIONS}" - fi - - DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` - - echo "run_cmd: ${run_cmd}" - - srun -l --verbose \ - --container-image /lustre/fsw/portfolios/llmservice/users/matthieul/docker/megatron-dev-img-05142025-pytorch-dev-te-cd37379-editable-energon-mamba-fix-vlmeval-pad-conv.sqsh \ - --container-mounts "/lustre" \ - --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ - sh -c "${run_cmd}" - - set +x -fi \ No newline at end of file diff --git a/interactive_pytorch25.02.sh b/interactive_pytorch25.02.sh deleted file mode 100755 index 3333cd934df..00000000000 --- a/interactive_pytorch25.02.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash - -# H100 -CONTAINER_IMAGE="/lustre/fsw/portfolios/llmservice/users/trintamaki/workspace/containers/megatron-dev-img-03232025-pytorch25.02-te-main-160be21-energon-develop-bd46613.sqsh" -CONTAINER_IMAGE="/lustre/fsw/portfolios/llmservice/users/trintamaki/workspace/containers/megatron-dev-img-03232025-pytorch25.02-te-main-160be21-energon-main-6.0.1.sqsh" -CONTAINER_IMAGE="/lustre/fsw/portfolios/llmservice/users/trintamaki/workspace/containers/megatron-dev-img-03232025-pytorch25.02-te-main-160be21-energon-develop-max-pr.sqsh" -CONTAINER_IMAGE="/lustre/fsw/portfolios/llmservice/users/trintamaki/workspace/containers/megatron-dev-img-03232025-pytorch25.02-te-main-160be21-energon-develop-max-pr-onelogger.sqsh" -CONTAINER_IMAGE="/lustre/fsw/portfolios/llmservice/users/trintamaki/workspace/containers/megatron-dev-img-03232025-pytorch25.02-te-main-160be21-energon-develop-max-pr-onelogger-max-pr2.sqsh" -CONTAINER_IMAGE="/lustre/fsw/portfolios/coreai/users/pmannan/workspace/megatron_vlm_25.02-te_api.sqsh" - -# Set partitions based on hostname -if [[ $(hostname) == *"oci-iad"* ]]; then - PARTITIONS="interactive,batch_singlenode,backfill_singlenode,backfill_block1,backfill_block3,backfill_block4,batch_block1,batch_block2,batch_block3,batch_block4" -elif [[ $(hostname) == *"cw-dfw"* ]]; then - PARTITIONS="interactive,batch" -elif [[ $(hostname) == *"oci-nrt"* ]]; then - PARTITIONS="interactive,batch_block1,backfill,batch_singlenode" -else - PARTITIONS="interactive" -fi - -srun -p ${PARTITIONS} -A coreai_dlalgo_genai -N 1 --pty \ - --container-image /lustre/fsw/portfolios/llmservice/users/matthieul/docker/megatron-dev-img-05142025-pytorch-dev-te-cd37379-editable-energon-mamba-fix-vlmeval-pad-conv.sqsh \ - --container-mounts "/lustre" \ - --gpus 8 \ - --exclusive \ - --job-name "coreai_dlalgo_genai-megatron-dev:interactive" \ - -t 1:00:00 \ - bash -l \ No newline at end of file From 6aebf37ff27e1e98df5012d1c53ad187cde2be31 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 2 Oct 2025 16:35:30 -0700 Subject: [PATCH 44/81] Restore SFT dataset/tokenizer to main --- megatron/training/datasets/sft_dataset.py | 205 +++++++------------ megatron/training/tokenizer/sft_tokenizer.py | 8 +- 2 files changed, 77 insertions(+), 136 deletions(-) diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index 6e89201984f..e4d8a6faf24 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -68,138 +68,79 @@ def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> LowL def __len__(self) -> int: return self.num_samples - def _split_conversations(self, merged_conversations): - split_conversations = [] - current = [] - for msg in merged_conversations: - # Whenever we see a new system message, start a new conversation - if msg["role"] == "system": - if current: # If previously accumulating a conversation, then store it - split_conversations.append(current) - current = [msg] # Then start the new conversation - else: - current.append(msg) # Continue accumulating the current conversation - if current: # Store any remaining conversation - split_conversations.append(current) - return split_conversations - def __getitem__(self, idx: int) -> Dict[str, Any]: tokenizer = self.config.tokenizer - pack_length = self.config.sequence_length - - merged_conversations = self.dataset[int(self.indices[idx % len(self.indices)])] - split_conversations = self._split_conversations(merged_conversations) - - def extend_with_padding(tokens, targets, positions, pad_len): - tokens.extend([pad] * pad_len) - targets.extend([pad] * pad_len) - positions.extend(range(positions[-1]+1, positions[-1]+1+pad_len)) - - pack_tokens = [] - pack_targets = [] - pack_positions = [] - cu_seqlens = [0] - eod = tokenizer.eod - pad = tokenizer.pad - # TODO(duncan): Track number of convs dropped and/or truncated and amount of end-padding - for conversation in split_conversations: - - tokens, targets = tokenizer.tokenize_conversation( - conversation, return_target=True, add_generation_prompt=False - ) - - tokens_list = tokens.tolist() - targets_list = targets.tolist() - - # Add EOD, unless it's already present - if tokens_list[-1] != eod: - tokens_list.append(eod) - targets_list.append(eod) - - pack_tokens.extend(tokens_list) - pack_targets.extend(targets_list) - - assert not self.config.reset_position_ids - pack_positions.extend(range(len(tokens_list))) - - if self.config.context_parallel_size > 1 or self.config.sequence_parallel_size > 1: - cp_pad = self.config.context_parallel_size * 2 if self.config.context_parallel_size > 1 else 1 - cp_pad = cp_pad * self.config.data_parallel_size if self.config.hybrid_context_parallel else cp_pad - tp_pad = self.config.sequence_parallel_size if self.config.sequence_parallel_size > 0 else 1 - pad_granularity = cp_pad * tp_pad - mod_token_count = len(pack_tokens) % pad_granularity - if mod_token_count != 0: - pad_len = pad_granularity - mod_token_count - extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) - - # TODO(duncan): Consider also padding to multiple of number of tokens here. This might - # be needed for efficiency (and potentially set via command-line argument). - - cu_seqlens.append(len(pack_tokens)) - - # Handle any necessary truncation - # - # Consider the case where the previous iteration led to - # len(pack_tokens) == pack_length. Then pack_tokens[pack_length-1] == eod. On this - # current iteration len(pack_tokens) >= pack_length + 1. Truncation here will then - # strip off the eod from the previous iteration and re-apply it. - # - # Consider the case where the previous iteration led to - # len(pack_tokens) == pack_length - 1. Then pack_tokens[pack_length-2] == eod. On - # this current iteration len(pack_tokens) >= pack_length + 1. Truncation here will - # then apply a second eod at location pack_tokens[pack_length-1]. So it is possible - # to have two eod tokens in a row. I'm not sure if this is a problem. - # - if len(pack_tokens) >= pack_length + 1: # +1 here to account for later alignment - max_body = pack_length - 1 - # Truncate on the left. TODO(duncan): Consider optionally trunc. on the right - pack_tokens = pack_tokens[-max_body:] - pack_targets = pack_targets[-max_body:] - pack_tokens.extend([eod, pad]) - pack_targets.extend([eod, pad]) - pack_positions = pack_positions[:pack_length+1] - # Note len({pack_tokens, pack_targets, pack_positions}) should be pack_length + 1 - cu_seqlens[-1] = len(pack_tokens) - 1 - break - - # Handle any necessary padding - if len(pack_tokens) < pack_length + 1: # +1 here to account for later alignment - pad_len = pack_length + 1 - len(pack_tokens) - extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) - # Note len({pack_tokens, pack_targets, pack_positions}) should be pack_length + 1 - cu_seqlens[-1] = len(pack_tokens) - 1 - - assert len(pack_tokens) == pack_length + 1 - assert len(pack_targets) == pack_length + 1 - assert len(pack_positions) == pack_length + 1 - - # Align and convert to tensors - input_ids = torch.tensor(pack_tokens[:-1], dtype=torch.int64) - labels = torch.tensor(pack_targets[1:], dtype=torch.int64) - position_ids = torch.tensor(pack_positions[:-1], dtype=torch.int64) - - loss_mask = torch.ones(pack_length, dtype=torch.float32) - loss_mask[labels == pad] = 0.0 # Mask paddings - loss_mask[labels == IGNORE_INDEX] = 0.0 # mask prompts - - # TODO(duncan): Optionally create an attention mask - assert not self.config.create_attention_mask and not self.config.reset_attention_mask - # attention_mask = None - - assert len(cu_seqlens) >= 2 - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32) - # Calculating max_seqlen here, rather than incrementally above, because of possible - # effects of truncation and padding - adjacent_diffs = cu_seqlens[1:] - cu_seqlens[:-1] - max_seqlen = adjacent_diffs.max() # max_seqlen is a 0-D tensor - - return { - 'tokens': input_ids, - 'labels': labels, - # 'attention_mask': attention_mask, # PyTorch collate cannot handle NoneType - 'loss_mask': loss_mask, - 'position_ids': position_ids, - 'cu_seqlens': cu_seqlens, - 'max_seqlen': max_seqlen, - } + max_seq_len = self.config.sequence_length + + conversation_list = self.dataset[int(self.indices[idx % len(self.indices)])] + tokens, target = tokenizer.tokenize_conversation( + conversation_list, return_target=True, add_generation_prompt=False + ) + + force_eod_length = int(tokenizer.force_eod) + + if len(tokens) > max_seq_len - force_eod_length: + tokens = tokens[: max_seq_len - force_eod_length] + target = target[: max_seq_len - force_eod_length] + + # padding + num_tokens = len(tokens) + force_eod_length + padding_len = max_seq_len - num_tokens + assert padding_len >= 0 + filler = [tokenizer.eod] * force_eod_length + [tokenizer.pad] * (padding_len + 1) + + tokens = np.array(tokens.tolist() + filler, dtype=np.int64) + target = np.array(target.tolist() + filler, dtype=np.int64) + + tokens = torch.tensor(tokens) + target = torch.tensor(target) + + tokens = tokens[:-1].contiguous() + target = target[1:].contiguous() + + loss_mask, position_ids, attention_mask = self._get_ltor_masks_and_position_ids( + max_seq_len, target, tokenizer.pad + ) + + if self.config.create_attention_mask: + ret = { + 'tokens': tokens, + 'labels': target, + 'attention_mask': attention_mask, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + } + else: + ret = { + 'tokens': tokens, + 'labels': target, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + } + + return ret + + def _get_ltor_masks_and_position_ids(self, max_seq_len, target, pad_token): + """Build masks and position id for left to right model for SFT""" + + assert not self.config.reset_position_ids and not self.config.reset_attention_mask + + # Position ids. + position_ids = torch.arange(max_seq_len, dtype=torch.long) + + # Loss mask. + loss_mask = torch.ones(max_seq_len, dtype=torch.float) + loss_mask[target == pad_token] = 0.0 # mask paddings + loss_mask[target == IGNORE_INDEX] = 0.0 # mask prompts + + if self.config.create_attention_mask: + attention_mask = torch.tril( + torch.ones((seq_length, seq_length), device=data.device) + ).unsqueeze(0) + # Convert attention mask to binary: + attention_mask = attention_mask < 0.5 + else: + attention_mask = None + + return loss_mask, position_ids, attention_mask diff --git a/megatron/training/tokenizer/sft_tokenizer.py b/megatron/training/tokenizer/sft_tokenizer.py index 4994851516b..4a941fc180b 100644 --- a/megatron/training/tokenizer/sft_tokenizer.py +++ b/megatron/training/tokenizer/sft_tokenizer.py @@ -42,11 +42,11 @@ def __init__( self._vocab_size = len(tokenizer) self._tokenizer = tokenizer - if prompt_format == "nemotron-h-aligned": + if prompt_format == "nemotron-nano-v2": self._prompt_config = PromptConfig( - assistant_prefix_len=0, - pad_token_id=tokenizer.convert_tokens_to_ids(""), - custom_chat_template=nemotron_h_aligned_custom_template, + assistant_prefix_len=3, + pad_token_id=tokenizer.convert_tokens_to_ids(""), + custom_chat_template=nemotron_nano_v2_custom_template, has_bos=False, has_system_role=True, ) From f5c0e76ca905b08cb59dbc78e48621c3cfdc0275 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 2 Oct 2025 16:46:20 -0700 Subject: [PATCH 45/81] Update pre-train script --- pretrain_gpt.py | 34 +++++++++++++++------------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index fafb7406d58..7358e53f75d 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -13,7 +13,6 @@ get_context_parallel_rank, get_context_parallel_world_size, ) -from megatron.core.pipeline_parallel.hybrid_cp_schedule import HybridCPDatasetWrapper from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset from megatron.core.enums import ModelType @@ -68,26 +67,24 @@ def get_batch(data_iterator, vp_stage=None): if local_cp_size is not None: local_cp_size = int(local_cp_size.item()) - if cu_seqlens is not None: - packed_seq_params = PackedSeqParams( - qkv_format="thd", - cu_seqlens_q=cu_seqlens[0], - cu_seqlens_kv=cu_seqlens[0], - cu_seqlens_q_padded=cu_seqlens[0], - cu_seqlens_kv_padded=cu_seqlens[0], - max_seqlen_q=int(max_seqlen[0].item()), - max_seqlen_kv=int(max_seqlen[0].item()), - ) - else: - packed_seq_params = None - if cu_seqlens is None and local_cp_size is None: # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) # The implementation of this function is in MCore + packed_seq_params = None elif local_cp_size is None: # Packed THD format cu_seqlens = cu_seqlens[0] assert max_seqlen.dim() == 1 + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens, + cu_seqlens_kv_padded=cu_seqlens, + max_seqlen_q=int(max_seqlen[0].item()), + max_seqlen_kv=int(max_seqlen[0].item()), + ) + cp_size = get_context_parallel_world_size() if cp_size > 1: # slice batch along sequence dimension for context parallelism assert tex is not None and is_te_min_version("1.10.0"), ( @@ -113,6 +110,7 @@ def get_batch(data_iterator, vp_stage=None): cp_group = None # Convert [seqlen] to [1, seqlen] similar to default collate_fn + # as hybrid_context_parallel dataloader wrapper does not go through default collate_fn for key, data in batch.items(): if key in ['attention_mask']: continue @@ -132,8 +130,8 @@ def get_batch(data_iterator, vp_stage=None): ) if cp_group is not None and cp_group.size() > 1: - # Each sub-sample of a packed sample is required to be divisible by CP*DP*2 - # or CP*DP*TP*2 (if using sequence parallel) + # When using hybrid_context_parallel, each sub-sample of a packed sample is + # required to be divisible by CP*DP*2 or CP*DP*TP*2 (if using sequence parallel) batch = get_batch_on_this_cp_rank(batch, cp_group.size(), torch.distributed.get_rank(group=cp_group)) return (*batch.values(), packed_seq_params) @@ -221,7 +219,7 @@ def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = Fa global stimer with stimer(bdata=True): vp_stage = get_attr_wrapped_model(model, "vp_stage") - tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = get_batch(data_iterator) + tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = get_batch(data_iterator, vp_stage) timers('batch-generator').stop() with stimer: @@ -308,8 +306,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None dataset_type, train_val_test_num_samples, partial(is_dataset_built_on_rank, vp_stage=vp_stage), config ).build() - # train_ds = HybridCPDatasetWrapper(train_ds) - print_rank_0("> finished creating GPT datasets ...") return train_ds, valid_ds, test_ds From 9d96178d0ffbcd4cfa6c5792ffc48496ab152c90 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 2 Oct 2025 16:55:31 -0700 Subject: [PATCH 46/81] Lint change --- megatron/core/transformer/attention.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index f465831e8d1..89c6a1fd250 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -47,7 +47,9 @@ rearrange = None try: - from flashattn_hopper.flash_attn_interface import _flash_attn_forward + from flashattn_hopper.flash_attn_interface import ( + _flash_attn_forward, + ) from flashattn_hopper.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) From b7aff32feab1a60d6d28ff926f3e02c48168e3ae Mon Sep 17 00:00:00 2001 From: Mcore Bot Date: Thu, 2 Oct 2025 23:59:26 +0000 Subject: [PATCH 47/81] chore: Format files --- megatron/core/transformer/attention.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 89c6a1fd250..f465831e8d1 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -47,9 +47,7 @@ rearrange = None try: - from flashattn_hopper.flash_attn_interface import ( - _flash_attn_forward, - ) + from flashattn_hopper.flash_attn_interface import _flash_attn_forward from flashattn_hopper.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) From 994e65fc84875b8d562d8c46d3ea14a0ac996d94 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 2 Oct 2025 17:02:28 -0700 Subject: [PATCH 48/81] Add assert --- megatron/core/extensions/transformer_engine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index bb7190a1dc8..33622be3a5b 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -1030,6 +1030,9 @@ def forward( # If cp_group is None but local_cp_size is provided, # Indicates to turn off CP dynamically elif packed_seq_params.local_cp_size is not None: + assert ( + packed_seq_params.local_cp_size == 1 + ), "local_cp_size must be == 1 if provided without cp_group" super().set_context_parallel_group(None, None, None, self.cp_comm_type) self.kept_packed_seq_params.discard("cp_group") self.kept_packed_seq_params.discard("local_cp_size") From aa6cf10285153c32b48384afaf246e50d6cd5e02 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 2 Oct 2025 17:42:40 -0700 Subject: [PATCH 49/81] Add support for Yarn Rope --- .../embeddings/yarn_rotary_pos_embedding.py | 31 ++++++++++++++++--- megatron/core/models/gpt/gpt_model.py | 4 ++- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py index 6cfeee98033..9e8a4ee30fd 100644 --- a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py @@ -99,13 +99,12 @@ def __init__( ) @lru_cache(maxsize=32) - def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor: + def get_emb(self, max_seq_len: int, offset: int = 0) -> Tensor: """Forward pass of Yarn Rotary Embedding. Args: max_seq_len (int): Maximum size of sequence offset (int, optional): RoPE offset. Defaults to 0. - packed_seq (bool, optional): Whether to use packed sequence. Defaults to False. Returns: Tensor: Embeddings after applying Yarn RoPE. @@ -151,10 +150,34 @@ def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) - emb = torch.cat((freqs, freqs), dim=-1) # emb [seq_length, .., dim] emb = emb[:, None, None, :] - if self.cp_group is not None and self.cp_group.size() > 1 and not packed_seq: + return emb, _mscale + + def forward(self, max_seq_len: int, offset: int = 0, packed_seq_params: Optional[PackedSeqParams] = None) -> Tensor: + """Forward pass of Yarn Rotary Embedding. + + Args: + max_seq_len (int): Maximum size of sequence + offset (int, optional): RoPE offset. Defaults to 0. + packed_seq_params (PackedSeqParams, optional): Packed sequence params. Defaults to None. + + Returns: + Tensor: Embeddings after applying Yarn RoPE. + """ + emb, _mscale = self.get_emb(max_seq_len, offset) + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + if packed_seq_params is not None and packed_seq_params.local_cp_size is not None: + if packed_seq_params.local_cp_size > 1: + # Set CP group to dynamic CP group for CP slicing + cp_group = packed_seq_params.cp_group + else: + # Set CP group to None to avoid CP slicing + cp_group = None + else: + cp_group = self.cp_group + if cp_group is not None and cp_group.size() > 1 and not packed_seq: # slice rotary_pos_emb along sequence dimension # and select the parition of the current CP rank - emb = get_pos_emb_on_this_cp_rank(emb, 0, self.cp_group) + emb = get_pos_emb_on_this_cp_rank(emb, 0, cp_group) return emb, _mscale def _set_cos_sin_cache(self, seq_len, offset, dtype, packed_seq=False): diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 25f9d6c0de1..adf43993f9f 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -348,7 +348,9 @@ def _preprocess( rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( inference_context, self.decoder, decoder_input, self.config, packed_seq_params ) - rotary_pos_emb, _ = self.rotary_pos_emb(rotary_seq_len) + rotary_pos_emb, _ = self.rotary_pos_emb( + rotary_seq_len, packed_seq_params=packed_seq_params + ) else: raise NotImplementedError( "Flash decoding uses precomputed cos and sin for RoPE, not implemented in " From f3a6e17cae99c90befc21a0b3298f86d333544ce Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 2 Oct 2025 17:47:47 -0700 Subject: [PATCH 50/81] Add support for MultimodalRope --- .../common/embeddings/rotary_pos_embedding.py | 16 +++++++++++++--- megatron/core/models/gpt/gpt_model.py | 4 +++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index 287b873efc4..e71629ec069 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -306,13 +306,14 @@ def __init__( else parallel_state.get_context_parallel_group(check_initialized=False) ) - def forward(self, position_ids: torch.Tensor, mrope_section: List[int]) -> Tensor: + def forward(self, position_ids: torch.Tensor, mrope_section: List[int], packed_seq_params: Optional[PackedSeqParams] = None) -> Tensor: """Forward pass of multimodal RoPE embedding. Args: position_ids (torch.Tensor): A postion_id tensor with shape [3, batchsize, seqlens] mrope_section (list[int]): Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + packed_seq_params (PackedSeqParams, optional): Packed sequence params. Defaults to None. Returns: Tensor: Embeddings after applying RoPE. @@ -345,8 +346,17 @@ def forward(self, position_ids: torch.Tensor, mrope_section: List[int]) -> Tenso # shape (seq_length, bs, 1, 2 * dim) emb = emb[..., None, :].transpose(0, 1).contiguous() - if self.cp_group is not None and self.cp_group.size() > 1: + if packed_seq_params is not None and packed_seq_params.local_cp_size is not None: + if packed_seq_params.local_cp_size > 1: + # Set CP group to dynamic CP group for CP slicing + cp_group = packed_seq_params.cp_group + else: + # Set CP group to None to avoid CP slicing + cp_group = None + else: + cp_group = self.cp_group + if cp_group is not None and cp_group.size() > 1: # slice rotary_pos_emb along sequence dimension and select the parition of the current # CP rank - emb = get_pos_emb_on_this_cp_rank(emb, 0, self.cp_group) + emb = get_pos_emb_on_this_cp_rank(emb, 0, cp_group) return emb diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index adf43993f9f..871bc56c74f 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -358,7 +358,9 @@ def _preprocess( ) elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention: if self.training or not self.config.flash_decode: - rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section) + rotary_pos_emb = self.rotary_pos_emb( + position_ids, self.mrope_section, packed_seq_params=packed_seq_params + ) else: # Flash decoding uses precomputed cos and sin for RoPE raise NotImplementedError( From a0fb848173ae1fe6404ed4df63425c740c9585f5 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 2 Oct 2025 18:11:42 -0700 Subject: [PATCH 51/81] General and RoPE support cleanup --- megatron/core/model_parallel_config.py | 6 +++--- .../models/common/embeddings/rotary_pos_embedding.py | 7 ++++++- .../common/embeddings/yarn_rotary_pos_embedding.py | 4 +++- megatron/core/pipeline_parallel/hybrid_cp_schedule.py | 5 ++--- megatron/core/transformer/attention.py | 4 +++- megatron/core/transformer/multi_latent_attention.py | 11 +++++++++-- 6 files changed, 26 insertions(+), 11 deletions(-) diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index ffa7579360f..c2eba8f594f 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -52,10 +52,10 @@ class ModelParallelConfig: type. """ - max_seqlen_per_cp_rank: Optional[int] = None + max_seqlen_per_dp_cp_rank: Optional[int] = None """ - Maximum sequence length per CP rank. This is used to calculate the number and length - of sub-samples assigned to each CP rank when using hybrid data x context parallel. + Maximum sequence length per DPxCP rank. This is used to calculate the number and length + of sub-samples assigned to each CP rank when using hybrid_context_parallel. """ hybrid_context_parallel: bool = False diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index e71629ec069..a8f94f46886 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -306,7 +306,12 @@ def __init__( else parallel_state.get_context_parallel_group(check_initialized=False) ) - def forward(self, position_ids: torch.Tensor, mrope_section: List[int], packed_seq_params: Optional[PackedSeqParams] = None) -> Tensor: + def forward( + self, + position_ids: torch.Tensor, + mrope_section: List[int], + packed_seq_params: Optional[PackedSeqParams] = None, + ) -> Tensor: """Forward pass of multimodal RoPE embedding. Args: diff --git a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py index 9e8a4ee30fd..c606dda3bb3 100644 --- a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py @@ -152,7 +152,9 @@ def get_emb(self, max_seq_len: int, offset: int = 0) -> Tensor: emb = emb[:, None, None, :] return emb, _mscale - def forward(self, max_seq_len: int, offset: int = 0, packed_seq_params: Optional[PackedSeqParams] = None) -> Tensor: + def forward( + self, max_seq_len: int, offset: int = 0, packed_seq_params: Optional[PackedSeqParams] = None + ) -> Tensor: """Forward pass of Yarn Rotary Embedding. Args: diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index cc1729ac504..d69398d1fc3 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -54,14 +54,14 @@ class HybridCPDataLoaderWrapper: Args: data_iterator: The original data_iterator to wrap around - config: The config object containing the max_seqlen_per_cp_rank + config: The config object containing the max_seqlen_per_dp_cp_rank """ def __init__(self, data_iterator, config): self.data_iterator = data_iterator self.config = config self.cp_balancing_scheduler = BalancedCPScheduler( - max_seq_len_per_rank=self.config.max_seqlen_per_cp_rank + max_seq_len_per_rank=self.config.max_seqlen_per_dp_cp_rank ) self.total_hdp_gpus = parallel_state.get_data_parallel_world_size( with_context_parallel=True @@ -866,7 +866,6 @@ def _get_new_data_iterator(sample_id_in_group, group_id): else: return None - cp_balancing_scheduler = BalancedCPScheduler(max_seq_len_per_rank=config.max_seqlen_per_cp_rank) # We get data once per global batch and schedule the sub-samples. # TODO(pmannan): Should we wrap the data_iterator here instead of the training.py file? hdp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index f465831e8d1..89c6a1fd250 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -47,7 +47,9 @@ rearrange = None try: - from flashattn_hopper.flash_attn_interface import _flash_attn_forward + from flashattn_hopper.flash_attn_interface import ( + _flash_attn_forward, + ) from flashattn_hopper.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index 3658e176ae0..a252227074f 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -478,6 +478,11 @@ def get_query_key_value_tensors( assert ( hidden_states.ndim == 3 ), f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D" + if packed_seq_params is not None: + assert ( + packed_seq_params.local_cp_size is None + ), "hybrid_context_parallel is not supported with MLA yet and is planned for future. \ + Please disable hybrid_context_parallel." inference_context = deprecate_inference_params(inference_context, inference_params) @@ -509,9 +514,11 @@ def get_query_key_value_tensors( and fused_apply_mla_rope_for_kv is not None ), "Fused MLA RoPE apply is not imported successfully" else: - rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) + rotary_pos_emb, mscale = self.rotary_pos_emb( + rotary_seq_len, packed_seq_params=packed_seq_params + ) - if packed_seq_params is not None: + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': if packed_seq_params.cu_seqlens_q_padded is not None: cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded else: From b0b0fabaa058d776e48036e6974b231df2036b4f Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 2 Oct 2025 18:18:25 -0700 Subject: [PATCH 52/81] Data sampler cleanup --- megatron/legacy/data/data_samplers.py | 16 ---------------- megatron/training/arguments.py | 5 +++-- 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/megatron/legacy/data/data_samplers.py b/megatron/legacy/data/data_samplers.py index 17cabd8d055..bb6a36e3903 100644 --- a/megatron/legacy/data/data_samplers.py +++ b/megatron/legacy/data/data_samplers.py @@ -11,22 +11,6 @@ from megatron.core import mpu from megatron.core.datasets.utils import Split -# class HybridCPCollator: -# def __init__(self): -# pass - -# def __call__(self, batch: List[Dict[str, Any]]): -# # IF WE WANT TO COLLATE BEFORE SCHEDULER -# # We get a list of samples when pulling multiple microbatches at once for scheduling with Hybrid DPxCP. -# # This would require us to change the scheduler and the data loader logic to support [1, S] instead of [S]. -# # if isinstance(batch, list): -# # for idx, sample in enumerate(batch): -# # collated_sample = torch.utils.data.default_collate([sample]) -# # batch[idx] = collated_sample -# # else: -# # batch = torch.utils.data.default_collate(batch) - -# return batch def build_pretraining_data_loader(dataset, consumed_samples): """Build dataloader given an input dataset.""" diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 9e3772050ce..f5a7ae7ee5b 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -906,8 +906,9 @@ def validate_args(args, defaults={}): assert not (args.hybrid_context_parallel and (args.pipeline_model_parallel_size > 1 or args.enable_cuda_graph - or args.use_megatron_fsdp) - ), "Hybrid context parallelism not supported with pipeline parallelism, CUDA Graph and FSDP now" + or args.use_megatron_fsdp + or args.dataloader_type != 'single') + ), "Hybrid context parallelism not supported with pipeline parallelism, CUDA Graph, FSDP and non-single dataloader type now" # disable async_tensor_model_parallel_allreduce when # model parallel memory optimization is enabled From 335994ce5e2a1b23d4af25f094809552c43f8875 Mon Sep 17 00:00:00 2001 From: Mcore Bot Date: Fri, 3 Oct 2025 01:22:54 +0000 Subject: [PATCH 53/81] chore: Format files --- megatron/core/transformer/attention.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 89c6a1fd250..f465831e8d1 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -47,9 +47,7 @@ rearrange = None try: - from flashattn_hopper.flash_attn_interface import ( - _flash_attn_forward, - ) + from flashattn_hopper.flash_attn_interface import _flash_attn_forward from flashattn_hopper.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) From ba99cadc4cf15fd1c4c683e4965cb804bb6c993f Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Fri, 3 Oct 2025 14:12:23 -0700 Subject: [PATCH 54/81] Cleanup hybrid_cp_schedule --- megatron/core/parallel_state.py | 5 +- .../pipeline_parallel/hybrid_cp_schedule.py | 71 +++++++++++-------- megatron/training/utils.py | 1 - 3 files changed, 47 insertions(+), 30 deletions(-) diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index be57a00f837..525a131e740 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -3,6 +3,7 @@ """Model and data parallel groups.""" import os +import logging import warnings from datetime import timedelta from math import log2 @@ -13,6 +14,8 @@ from .utils import GlobalMemoryBuffer, is_torch_min_version +logger = logging.getLogger(__name__) + try: import einops @@ -843,7 +846,7 @@ def initialize_model_parallel( # Apply SHARP to the dp group. if sharp_enabled_group == "dp": if rank == 0: - print( + logger.info( "The number of process groups to use SHARP with depends on the type " "of the network switch. Nvidia QM1 switch supports SAHRP up to 8 " "process groups and QM2 supports up to 256 process groups. We apply " diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index d69398d1fc3..53e5dd61bd6 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -8,6 +8,7 @@ import torch from megatron.core import parallel_state +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.rerun_state_machine import RerunDataIterator @@ -55,17 +56,23 @@ class HybridCPDataLoaderWrapper: Args: data_iterator: The original data_iterator to wrap around config: The config object containing the max_seqlen_per_dp_cp_rank + dp_cp_group: Data parallel context parallel group. """ - def __init__(self, data_iterator, config): + def __init__(self, data_iterator, config, pg_collection: Optional[ProcessGroupCollection] = None,): self.data_iterator = data_iterator self.config = config self.cp_balancing_scheduler = BalancedCPScheduler( max_seq_len_per_rank=self.config.max_seqlen_per_dp_cp_rank ) - self.total_hdp_gpus = parallel_state.get_data_parallel_world_size( - with_context_parallel=True - ) + if pg_collection is None: + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + self.dp_cp_group = pg_collection.dp_cp + self.dp_group = pg_collection.dp + self.tp_group = pg_collection.tp + assert self.dp_cp_group is not None and self.dp_group is not None, "dp_cp_group and dp_group not found in pg_collection" + + self.total_hdp_gpus = self.dp_cp_group.size() def __iter__(self): """Return self as an iterator.""" @@ -84,10 +91,10 @@ def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: local_len = torch.tensor([subsample_seqlens.shape[0]], dtype=torch.int32).cuda() dp_subsample_count = [ torch.zeros_like(local_len) - for _ in range(parallel_state.get_data_parallel_world_size()) + for _ in range(self.dp_group.size()) ] torch.distributed.all_gather( - dp_subsample_count, local_len, group=parallel_state.get_data_parallel_group() + dp_subsample_count, local_len, group=self.dp_group ) # Find the max number of subsamples across all ranks and pad subsample_seqlens to max length @@ -108,12 +115,12 @@ def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: # Gather the subsample_seqlens from all ranks seqlens_gathered = [ torch.empty_like(subsample_seqlens_padded) - for _ in range(parallel_state.get_data_parallel_world_size()) + for _ in range(self.dp_group.size()) ] torch.distributed.all_gather( seqlens_gathered, subsample_seqlens_padded, - group=parallel_state.get_data_parallel_group(), + group=self.dp_group, ) # Trim each seqlens_gathered to the length of the correct sample @@ -139,7 +146,7 @@ def get_global_id_seqlens(self, num_local_subsamples, offsets, seqlens_gathered) global_id_seqlens: list of (global_id, seqlen) tuples for scheduling. global_ids_this_rank: list of global IDs locally present on this rank. """ - dp_rank = parallel_state.get_data_parallel_rank() + dp_rank = self.dp_group.rank() global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda() # Create a list of (global_id, seqlen) tuples for scheduling global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] @@ -155,10 +162,10 @@ def _gid_to_src_rank(self, gid: int, offsets: List[int]) -> int: # Since the torch.distributed.get_process_group_ranks # provides the global rank, we need to consider TP hdp_rank = ( - torch.distributed.get_process_group_ranks(parallel_state.get_data_parallel_group())[ + torch.distributed.get_process_group_ranks(self.dp_group)[ dp_src_rank ] - // parallel_state.get_tensor_model_parallel_world_size() + // self.tp_group.size() ) return hdp_rank @@ -174,13 +181,13 @@ def reroute_samples_to_hdp_ranks( to transfer data between matching CP ranks. """ gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} - hdp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) + hdp_rank = self.dp_cp_group.rank() dp_ranks = torch.distributed.get_process_group_ranks( - parallel_state.get_data_parallel_group() + self.dp_group ) # Here we actually want to get the DP group's rank within the HDP group, # we need to consider TP - dp_ranks = [r // parallel_state.get_tensor_model_parallel_world_size() for r in dp_ranks] + dp_ranks = [r // self.tp_group.size() for r in dp_ranks] data_keys = batch[0].keys() @@ -264,7 +271,7 @@ def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): input=send_tensor, output_split_sizes=recv_lens_split, input_split_sizes=send_lens_split, - group=parallel_state.get_data_parallel_group(with_context_parallel=True), + group=self.dp_cp_group, ) _unpack_sample_by_key(key, recv_tensor) @@ -357,8 +364,7 @@ def get_total_workload(self, seq_length: int, cp_size: Optional[int] = None): This function is used to estimate the relative workload intensity of a sub-sample. This is not meant to be an accurate flops calculator. - Returns: - workload: workload of a sub-sample + Returns: workload of a sub-sample """ if cp_size is None: cp_size = self.gpus_needed(seq_length) @@ -370,6 +376,9 @@ def gpus_needed(self, seq_len: int) -> int: Calculates the number of GPUs needed for a given sequence length and max sequence length per CP rank. This is used to determine the CP size of a sub-sample. + + The number is rounded up to the next power of 2 to match the available + hybrid context parallel process group sizes. """ return max(1, 2 ** ceil(log2((seq_len / self.max_seq_len_per_rank)))) @@ -379,7 +388,7 @@ def make_buckets_equal( compute_estimator: Callable[[int], float], ) -> List[deque]: """ - Modified version of make_buckets_equal_work that works with (sample_id, seq_len) tuples. + Makes as many buckets as unique CP sizes needed. This keeps sample IDs tethered to their sequence lengths throughout the bucketing process. """ # Extract just the sequence lengths for determining k @@ -388,8 +397,8 @@ def make_buckets_equal( # Determine k based on unique GPU categories needed k = len({self.gpus_needed(L) for L in seqlens}) - # Use the existing contiguous_equal_buckets function but with sample_seqlens - # We need to modify it to work with tuples + # Create a work target for each bucket + # This is the total work divided by the number of buckets work = [] for _, s in sample_seqlens: cp_size = self.gpus_needed(s) @@ -462,7 +471,7 @@ def next_hdp_group( [[] for _ in range(total_gpus)], ) - # Use the improved bucketing that works with (sample_id, seq_len) tuples + # Get buckets of sequences with balanced work buckets = self.make_buckets_equal(sample_seqlens, compute_estimator) # Initialize tracking structures @@ -510,6 +519,7 @@ def next_hdp_group( if sample_seq_tuple is None: break + # TODO[pmannan]: PP not yet supported. Add PP scheduling. if strategy == "pp": pp_cursor = (bucket_idx + 1) % len(buckets) @@ -544,11 +554,9 @@ def next_hdp_group( else: chosen_members = group_members[best_gid] else: - if best_gid is None: - print(f"No room to form a new group") chosen_members = group_members[best_gid] - # ---- Step 2b – if we decided to create a fresh group ---------------- + # ---- Step 2 – if we decided to create a fresh group ---------------- if best_gid is None: best_gid = next_gid next_gid += 1 @@ -625,6 +633,11 @@ def trim_overload(): cur_min = min(exec_times) cur_slack = cur_max - cur_min if cur_slack <= delta * cur_max: + # Slack is already within limit. + break + if cur_min == 0: + # There are empty GPUs that will be + # handled in the next step. break max_r = exec_times.index(cur_max) @@ -644,6 +657,7 @@ def trim_overload(): proj_slack = max(proj_times) - min(proj_times) + # Check if trimming the workload helps imbalance if proj_slack < cur_slack: sample_id_to_remove = sample_ids_per_gpu[max_r][-1] for r in members: @@ -656,7 +670,7 @@ def trim_overload(): trim_overload() - # Track work before redistribution + # Track samples in this group before redistribution to empty GPUs total_work_before = sum(len(mb) for mb in micro_batches) # Check for empty GPUs and redistribute work @@ -666,7 +680,8 @@ def fill_empty_gpus( """ Recursively check for empty GPUs and redistribute work by increasing the number of GPUs sharing samples. This ensures all GPUs have work. - GPUs must be allocated consecutively. + GPUs must be allocated consecutively so we may need to push existing + work to other ranks in order to expand samples. """ # Find empty GPUs empty_gpus = [i for i in range(total_gpus) if not micro_batches[i]] @@ -762,11 +777,11 @@ def fill_empty_gpus( ) empty_gpus = any([not micro_batches[i] for i in range(total_gpus)]) - # Assert that no work has been completely removed + # Assert that no sample has been completely removed total_work_after = sum(len(mb) for mb in micro_batches) assert ( total_work_after >= total_work_before - ), f"Work was removed: {total_work_before} -> {total_work_after}" + ), f"Samples were removed: {total_work_before} -> {total_work_after}" return micro_batches, leftovers, exec_times, sample_ids_per_gpu diff --git a/megatron/training/utils.py b/megatron/training/utils.py index d40bf0be27e..6a9cec07ae4 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -545,7 +545,6 @@ def _broadcast(item): def _broadcast_cu_seqlens(cu_seqlens): dev = torch.cuda.current_device() - # when enabling hybrid context parallel, cu_seqlens is always None n = 0 if cu_seqlens is None else int(cu_seqlens.numel()) n_tensor = torch.tensor(n, dtype=torch.int64, device=dev) _broadcast(n_tensor) From d3c0c03c5db2f803a92747f0e4fde6610705e409 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Fri, 3 Oct 2025 14:13:46 -0700 Subject: [PATCH 55/81] lint fixes --- megatron/core/parallel_state.py | 2 +- .../pipeline_parallel/hybrid_cp_schedule.py | 32 +++++++------------ 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 525a131e740..722f4e02706 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -2,8 +2,8 @@ """Model and data parallel groups.""" -import os import logging +import os import warnings from datetime import timedelta from math import log2 diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index 53e5dd61bd6..31e7c9d4d22 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -59,7 +59,9 @@ class HybridCPDataLoaderWrapper: dp_cp_group: Data parallel context parallel group. """ - def __init__(self, data_iterator, config, pg_collection: Optional[ProcessGroupCollection] = None,): + def __init__( + self, data_iterator, config, pg_collection: Optional[ProcessGroupCollection] = None + ): self.data_iterator = data_iterator self.config = config self.cp_balancing_scheduler = BalancedCPScheduler( @@ -70,7 +72,9 @@ def __init__(self, data_iterator, config, pg_collection: Optional[ProcessGroupCo self.dp_cp_group = pg_collection.dp_cp self.dp_group = pg_collection.dp self.tp_group = pg_collection.tp - assert self.dp_cp_group is not None and self.dp_group is not None, "dp_cp_group and dp_group not found in pg_collection" + assert ( + self.dp_cp_group is not None and self.dp_group is not None + ), "dp_cp_group and dp_group not found in pg_collection" self.total_hdp_gpus = self.dp_cp_group.size() @@ -89,13 +93,8 @@ def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: """ # Collect the number of subsamples from all ranks local_len = torch.tensor([subsample_seqlens.shape[0]], dtype=torch.int32).cuda() - dp_subsample_count = [ - torch.zeros_like(local_len) - for _ in range(self.dp_group.size()) - ] - torch.distributed.all_gather( - dp_subsample_count, local_len, group=self.dp_group - ) + dp_subsample_count = [torch.zeros_like(local_len) for _ in range(self.dp_group.size())] + torch.distributed.all_gather(dp_subsample_count, local_len, group=self.dp_group) # Find the max number of subsamples across all ranks and pad subsample_seqlens to max length dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1) @@ -114,13 +113,10 @@ def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: # Gather the subsample_seqlens from all ranks seqlens_gathered = [ - torch.empty_like(subsample_seqlens_padded) - for _ in range(self.dp_group.size()) + torch.empty_like(subsample_seqlens_padded) for _ in range(self.dp_group.size()) ] torch.distributed.all_gather( - seqlens_gathered, - subsample_seqlens_padded, - group=self.dp_group, + seqlens_gathered, subsample_seqlens_padded, group=self.dp_group ) # Trim each seqlens_gathered to the length of the correct sample @@ -162,9 +158,7 @@ def _gid_to_src_rank(self, gid: int, offsets: List[int]) -> int: # Since the torch.distributed.get_process_group_ranks # provides the global rank, we need to consider TP hdp_rank = ( - torch.distributed.get_process_group_ranks(self.dp_group)[ - dp_src_rank - ] + torch.distributed.get_process_group_ranks(self.dp_group)[dp_src_rank] // self.tp_group.size() ) return hdp_rank @@ -182,9 +176,7 @@ def reroute_samples_to_hdp_ranks( """ gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} hdp_rank = self.dp_cp_group.rank() - dp_ranks = torch.distributed.get_process_group_ranks( - self.dp_group - ) + dp_ranks = torch.distributed.get_process_group_ranks(self.dp_group) # Here we actually want to get the DP group's rank within the HDP group, # we need to consider TP dp_ranks = [r // self.tp_group.size() for r in dp_ranks] From 1712980dabc072fe0125e93d5eb4d73bcac551bb Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Fri, 3 Oct 2025 16:04:52 -0700 Subject: [PATCH 56/81] Fix MLA + Yarn calls --- .../common/embeddings/yarn_rotary_pos_embedding.py | 11 ++++++----- megatron/core/transformer/multi_latent_attention.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py index c606dda3bb3..c77ffb298f2 100644 --- a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py @@ -182,13 +182,13 @@ def forward( emb = get_pos_emb_on_this_cp_rank(emb, 0, cp_group) return emb, _mscale - def _set_cos_sin_cache(self, seq_len, offset, dtype, packed_seq=False): + def _set_cos_sin_cache(self, seq_len, offset, dtype, packed_seq_params=None): self.max_seq_len_cached = seq_len self.offset_cached = offset self.dtype_cached = dtype - self.packed_seq_cached = packed_seq + self.packed_seq_cached = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' - emb, _mscale = self.forward(seq_len, offset, packed_seq) + emb, _mscale = self.forward(seq_len, offset, packed_seq_params) self.register_buffer( "cos_cached", (emb.cos() * _mscale).to(dtype).contiguous(), persistent=False ) @@ -197,16 +197,17 @@ def _set_cos_sin_cache(self, seq_len, offset, dtype, packed_seq=False): ) def get_cached_cos_sin( - self, seq_len, offset=0, dtype=torch.get_default_dtype(), packed_seq=False + self, seq_len, offset=0, dtype=torch.get_default_dtype(), packed_seq_params=None ): """Get cached cos and sin values.""" + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' if ( seq_len > self.max_seq_len_cached or offset != self.offset_cached or dtype != self.dtype_cached or packed_seq != self.packed_seq_cached ): - self._set_cos_sin_cache(seq_len, offset, dtype, packed_seq) + self._set_cos_sin_cache(seq_len, offset, dtype, packed_seq_params) return (self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]) diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index a252227074f..60ba361142b 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -505,7 +505,7 @@ def get_query_key_value_tensors( else: if self.config.apply_rope_fusion: rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb.get_cached_cos_sin( - rotary_seq_len, dtype=hidden_states.dtype, packed_seq=packed_seq + rotary_seq_len, dtype=hidden_states.dtype, packed_seq_params=packed_seq_params ) rotary_pos_emb = None assert inference_context is None, "Inference with MLA RoPE fusion is not supported" From 62817e820514a67d147ec1bc664cba42c210645f Mon Sep 17 00:00:00 2001 From: Mcore Bot Date: Fri, 3 Oct 2025 23:39:12 +0000 Subject: [PATCH 57/81] chore: Format files --- .../models/common/embeddings/yarn_rotary_pos_embedding.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py index c77ffb298f2..7ea398da51f 100644 --- a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py @@ -186,7 +186,9 @@ def _set_cos_sin_cache(self, seq_len, offset, dtype, packed_seq_params=None): self.max_seq_len_cached = seq_len self.offset_cached = offset self.dtype_cached = dtype - self.packed_seq_cached = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + self.packed_seq_cached = ( + packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + ) emb, _mscale = self.forward(seq_len, offset, packed_seq_params) self.register_buffer( From e586fa43cb4d5a82c4f645e707c8a7ba6e1024dc Mon Sep 17 00:00:00 2001 From: kunlunl Date: Thu, 9 Oct 2025 02:40:16 -0700 Subject: [PATCH 58/81] ProcessGroupCollection.use_mpu_process_groups() doesn't initialize dp_cp group --- .../core/pipeline_parallel/hybrid_cp_schedule.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index 31e7c9d4d22..e2e6d8dc4d5 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -68,13 +68,16 @@ def __init__( max_seq_len_per_rank=self.config.max_seqlen_per_dp_cp_rank ) if pg_collection is None: - pg_collection = ProcessGroupCollection.use_mpu_process_groups() - self.dp_cp_group = pg_collection.dp_cp - self.dp_group = pg_collection.dp - self.tp_group = pg_collection.tp + self.dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) + self.dp_group = parallel_state.get_data_parallel_group() + self.tp_group = parallel_state.get_tensor_model_parallel_group() + else: + self.dp_cp_group = pg_collection.dp_cp + self.dp_group = pg_collection.dp + self.tp_group = pg_collection.tp assert ( self.dp_cp_group is not None and self.dp_group is not None - ), "dp_cp_group and dp_group not found in pg_collection" + ), "dp_cp_group and dp_group not found" self.total_hdp_gpus = self.dp_cp_group.size() From 4221bdc9ca3130c042de211a6f3b546a3584570c Mon Sep 17 00:00:00 2001 From: kunlunl Date: Thu, 9 Oct 2025 02:43:01 -0700 Subject: [PATCH 59/81] Fix num_microbatches always being 0 --- .../pipeline_parallel/hybrid_cp_schedule.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index e2e6d8dc4d5..c4b92427561 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -898,8 +898,7 @@ def _get_new_data_iterator(sample_id_in_group, group_id): num_samples_this_group = num_samples_this_group.cpu().numpy() num_total_groups = num_samples_this_group.shape[0] - # TODO: How does this variable affect downstream logic? - num_microbatches = 1 + current_microbatch = 0 # Upto last group, we don't need any sync. with no_sync_func(): @@ -920,10 +919,11 @@ def _get_new_data_iterator(sample_id_in_group, group_id): config, collect_non_loss_data, is_first_microbatch=check_first_val_step( - first_val_step, forward_only, num_microbatches == 1 + first_val_step, forward_only, current_microbatch == 0 ), - current_microbatch=num_microbatches - 1, + current_microbatch=current_microbatch, ) + current_microbatch += 1 total_num_tokens += num_tokens.item() if not forward_only: backward_step( @@ -953,10 +953,11 @@ def _get_new_data_iterator(sample_id_in_group, group_id): config, collect_non_loss_data, is_first_microbatch=check_first_val_step( - first_val_step, forward_only, num_microbatches == 1 + first_val_step, forward_only, current_microbatch == 0 ), - current_microbatch=num_microbatches - 1, + current_microbatch=current_microbatch, ) + current_microbatch += 1 total_num_tokens += num_tokens.item() if not forward_only: backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) @@ -975,9 +976,9 @@ def _get_new_data_iterator(sample_id_in_group, group_id): config, collect_non_loss_data, is_first_microbatch=check_first_val_step( - first_val_step, forward_only, num_microbatches == 1 + first_val_step, forward_only, current_microbatch == 0 ), - current_microbatch=num_microbatches - 1, + current_microbatch=current_microbatch, ) total_num_tokens += num_tokens.item() if not forward_only: From 0856c61228858932b1f8ccc70688b2840043d03b Mon Sep 17 00:00:00 2001 From: kunlunl Date: Thu, 9 Oct 2025 02:45:22 -0700 Subject: [PATCH 60/81] Raise error when not using per-token loss --- megatron/training/arguments.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index f5a7ae7ee5b..34dd3d3aa32 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -903,12 +903,12 @@ def validate_args(args, defaults={}): if args.tp_comm_overlap: assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled' - assert not (args.hybrid_context_parallel - and (args.pipeline_model_parallel_size > 1 - or args.enable_cuda_graph - or args.use_megatron_fsdp - or args.dataloader_type != 'single') - ), "Hybrid context parallelism not supported with pipeline parallelism, CUDA Graph, FSDP and non-single dataloader type now" + if args.hybrid_context_parallel: + assert not args.pipeline_model_parallel_size > 1, 'Hybrid context parallelism not supported with pipeline parallelism' + assert not args.enable_cuda_graph, 'Hybrid context parallelism not supported with CUDA Graph' + assert not args.use_megatron_fsdp, 'Hybrid context parallelism not supported with Megatron FSDP' + assert args.dataloader_type == 'single', 'Hybrid context parallelism only supported with single dataloader type' + assert args.calculate_per_token_loss, 'Hybrid context parallelism must be used with --calculate-per-token-loss' # disable async_tensor_model_parallel_allreduce when # model parallel memory optimization is enabled From d2db69ecf2038d3fd49f02fe0e16c6aa8e11632c Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 9 Oct 2025 15:21:53 -0700 Subject: [PATCH 61/81] Minor fix --- megatron/core/pipeline_parallel/hybrid_cp_schedule.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index c4b92427561..96b6b86e7c4 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -76,8 +76,8 @@ def __init__( self.dp_group = pg_collection.dp self.tp_group = pg_collection.tp assert ( - self.dp_cp_group is not None and self.dp_group is not None - ), "dp_cp_group and dp_group not found" + self.dp_cp_group is not None and self.dp_group is not None and self.tp_group is not None + ), "dp_cp_group, dp_group, tp_group must not be None when using hybrid context parallel" self.total_hdp_gpus = self.dp_cp_group.size() @@ -984,6 +984,4 @@ def _get_new_data_iterator(sample_id_in_group, group_id): if not forward_only: backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - torch.distributed.barrier(parallel_state.get_data_parallel_group(with_context_parallel=True)) - return forward_data_store, total_num_tokens From 627753b8974ee0873d1d9dd43baa60733d530230 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Wed, 15 Oct 2025 18:21:43 -0700 Subject: [PATCH 62/81] Address comments --- megatron/core/datasets/gpt_dataset.py | 10 ++- megatron/core/datasets/utils.py | 2 +- megatron/core/model_parallel_config.py | 8 +- .../pipeline_parallel/hybrid_cp_schedule.py | 32 ------- megatron/core/utils.py | 84 +++++++++++++++++++ megatron/legacy/data/data_samplers.py | 6 ++ pretrain_gpt.py | 62 +------------- 7 files changed, 108 insertions(+), 96 deletions(-) diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index 959f5ef6d28..ef4ac923bf7 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -59,10 +59,16 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig): """Option to enable data parallelism""" sequence_parallel_size: int = 0 - """Option to indicate the sequence parallelism size when using TP""" + """Option to indicate the sequence parallelism size when using TP + Set to 0 if sequence parallel is not enabled regardless of TP size. + """ hybrid_context_parallel: bool = False - """Option to enable hybrid context parallelism""" + """Option to enable hybrid context parallelism. When setting this to True, + each sample should be divisible by the data parallel size * context parallel size * 2. + If sequence parallel is enabled, it should be divisible by the + data parallel size * context parallel size * sequence parallel size * 2. + """ def __post_init__(self) -> None: """Do asserts and set fields post init""" diff --git a/megatron/core/datasets/utils.py b/megatron/core/datasets/utils.py index e14656df799..656f969aff2 100644 --- a/megatron/core/datasets/utils.py +++ b/megatron/core/datasets/utils.py @@ -89,4 +89,4 @@ def get_blend_from_list( prefix_per_dataset = [rppd.strip() for rppd in raw_prefix_per_dataset] - return prefix_per_dataset, weight_per_dataset + return prefix_per_dataset, weight_per_dataset \ No newline at end of file diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index c2eba8f594f..fcbfcac208f 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -54,14 +54,18 @@ class ModelParallelConfig: max_seqlen_per_dp_cp_rank: Optional[int] = None """ - Maximum sequence length per DPxCP rank. This is used to calculate the number and length - of sub-samples assigned to each CP rank when using hybrid_context_parallel. + Maximum sequence length per DPxCP rank. This is the maximum sequence length each rank + can handle without overflowing the memory. Typically, a good starting point is to set this + to maximum sequence length / context parallel size. + This is used to calculate the number and length of sub-samples assigned to + each rank when using hybrid_context_parallel. """ hybrid_context_parallel: bool = False """ If true, enables hybrid context parallel. This is used to balance the workload of each CP rank when we use packed samples with variable sequence lengths. + Please set max_seqlen_per_dp_cp_rank when using hybrid_context_parallel. """ expert_model_parallel_size: int = 1 diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index 96b6b86e7c4..f2f8caffc8a 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -11,38 +11,6 @@ from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.rerun_state_machine import RerunDataIterator - -class HybridCPDatasetWrapper: - """ - A wrapper class that wraps around any existing dataset. - It adds batch_idx to the sample and returns the sample. - - This is required if certains DP ranks need the dataset idx - to access the appropriate sub-sample assigned to them from - the file system even if it was not originally assigned to them. - - Args: - dataset: The original dataset to wrap around - """ - - def __init__(self, dataset): - self.dataset = dataset - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, idx) -> Any: - """ - Get the next item from the dataset, add batch_idx and return it. - """ - sample = self.dataset[idx] - sample["batch_idx"] = idx - assert ( - "cu_seqlens" in sample - ), "cu_seqlens must be in the sample to use hybrid context parallel" - return sample - - class HybridCPDataLoaderWrapper: """ A wrapper class that wraps around an existing data_iterator. diff --git a/megatron/core/utils.py b/megatron/core/utils.py index d27eef1c8e6..8412c325207 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -1855,6 +1855,90 @@ def get_batch_on_this_cp_rank( return batch +def get_thd_batch_on_this_cp_rank( + batch: Dict[str, Any], cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, cp_size: Optional[int] = None, cp_rank: Optional[int] = None +): + """Slice each sub-sample in a packed sample batch input along + sequence dimension into multiple chunks, which are parallelized + across GPUs in a context parallel group. + """ + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens, + cu_seqlens_kv_padded=cu_seqlens, + max_seqlen_q=int(max_seqlen[0].item()), + max_seqlen_kv=int(max_seqlen[0].item()), + ) + + cp_size = get_context_parallel_world_size() if cp_size is None else cp_size + cp_rank = get_context_parallel_rank() if cp_rank is None else cp_rank + if cp_size > 1: # slice batch along sequence dimension for context parallelism + assert tex is not None and is_te_min_version("1.10.0"), ( + "Please update Transformer Engine to >= 1.10 to use " + "Context Parallel with THD format data" + ) + index = tex.thd_get_partitioned_indices( + cu_seqlens, + batch['tokens'].size(1), + cp_size, + cp_rank, + ) + for key, data in batch.items(): + if key in {'attention_mask', 'cu_seqlens', 'max_seqlen'}: + continue + batch[key] = data.index_select(1, index) + + return batch, packed_seq_params + + +################################ +### hybrid context parallel ### +################################ + +def get_batch_on_this_hybrid_cp_rank( + batch: Dict[str, Any], local_cp_size: int, cp_group: Optional[torch.distributed.ProcessGroup] = None +): + """Slice batch input along sequence dimension into multiple chunks, + which are parallelized across GPUs in a context parallel group. + """ + assert local_cp_size is not None + if cp_group is None: + # Get the local cp group required for this sample as defined by the Hybrid CP dataloader wrapper + if local_cp_size > 1: + cp_group = parallel_state.get_hybrid_data_context_parallel_groups(group_size=local_cp_size) + else: + # If cp group is provided, it must match the local cp size defined by the Hybrid CP dataloader wrapper + assert cp_group.size() == local_cp_size + + # Convert [seqlen] to [1, seqlen] similar to default collate_fn + # as hybrid_context_parallel dataloader wrapper does not go through default collate_fn + for key, data in batch.items(): + if key in ['attention_mask']: + continue + batch[key] = torch.stack([data], 0) + sample_length = batch['tokens'].shape[1] + # Create packed_seq_params for SBHD format with cp group information. + packed_seq_params = PackedSeqParams( + qkv_format="sbhd", + cu_seqlens_q=torch.tensor([0, sample_length], device="cuda", pin_memory=True), + cu_seqlens_kv=torch.tensor([0, sample_length], device="cuda", pin_memory=True), + cu_seqlens_q_padded=torch.tensor([0, sample_length], device="cuda", pin_memory=True), + cu_seqlens_kv_padded=torch.tensor([0, sample_length], device="cuda", pin_memory=True), + max_seqlen_q=sample_length, + max_seqlen_kv=sample_length, + local_cp_size=local_cp_size, + cp_group=cp_group, + ) + + if cp_group is not None and cp_group.size() > 1: + # When using hybrid_context_parallel, each sub-sample of a packed sample is + # required to be divisible by CP*DP*2 or CP*DP*TP*2 (if using sequence parallel) + batch = get_batch_on_this_cp_rank(batch, cp_group.size(), torch.distributed.get_rank(group=cp_group)) + + return batch, packed_seq_params + ###################### ### NVTX profiling ### ###################### diff --git a/megatron/legacy/data/data_samplers.py b/megatron/legacy/data/data_samplers.py index bb6a36e3903..79bdc7b193f 100644 --- a/megatron/legacy/data/data_samplers.py +++ b/megatron/legacy/data/data_samplers.py @@ -129,6 +129,12 @@ def __iter__(self): yield batch[start_idx:end_idx] class HybridCPMegatronPretrainingSampler(MegatronPretrainingSampler): + """ + Data sampler for hybrid context parallel (Hybrid CP) format. + This data sampler pulls in the entire global batch at once across all data parallel ranks. + This helps provide the Hybrid CP Dataloader Wrapper to schedule and load balance sub-samples + of the entire global batch. + """ def __init__(self, total_samples, consumed_samples, micro_batch_size, global_batch_size, data_parallel_rank, data_parallel_size, drop_last=True): diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 7358e53f75d..0e7778d7f6a 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -18,7 +18,7 @@ from megatron.core.enums import ModelType from megatron.core.models.gpt import GPTModel from megatron.core.rerun_state_machine import get_rerun_state_machine -from megatron.core.utils import get_attr_wrapped_model, is_te_min_version, StragglerDetector +from megatron.core.utils import get_attr_wrapped_model, get_thd_batch_on_this_cp_rank, get_batch_on_this_hybrid_cp_rank, StragglerDetector from megatron.core.tokenizers.text.utils.build_tokenizer import build_tokenizer from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0 from megatron.training.utils import ( @@ -74,65 +74,9 @@ def get_batch(data_iterator, vp_stage=None): elif local_cp_size is None: # Packed THD format cu_seqlens = cu_seqlens[0] assert max_seqlen.dim() == 1 - - packed_seq_params = PackedSeqParams( - qkv_format="thd", - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - cu_seqlens_q_padded=cu_seqlens, - cu_seqlens_kv_padded=cu_seqlens, - max_seqlen_q=int(max_seqlen[0].item()), - max_seqlen_kv=int(max_seqlen[0].item()), - ) - - cp_size = get_context_parallel_world_size() - if cp_size > 1: # slice batch along sequence dimension for context parallelism - assert tex is not None and is_te_min_version("1.10.0"), ( - "Please update Transformer Engine to >= 1.10 to use " - "Context Parallel with THD format data" - ) - cp_rank = get_context_parallel_rank() - index = tex.thd_get_partitioned_indices( - cu_seqlens, - batch['tokens'].size(1), - cp_size, - cp_rank, - ) - for key, data in batch.items(): - if key in {'attention_mask'}: - continue - batch[key] = data.index_select(1, index) + batch, packed_seq_params = get_thd_batch_on_this_cp_rank(batch, cu_seqlens, max_seqlen) else: # Hybrid CP format - assert local_cp_size is not None - if local_cp_size > 1: - cp_group = parallel_state.get_hybrid_data_context_parallel_groups(group_size=local_cp_size) - else: - cp_group = None - - # Convert [seqlen] to [1, seqlen] similar to default collate_fn - # as hybrid_context_parallel dataloader wrapper does not go through default collate_fn - for key, data in batch.items(): - if key in ['attention_mask']: - continue - batch[key] = torch.stack([data], 0) - sample_length = batch['tokens'].shape[1] - # Create packed_seq_params for SBHD format with cp group information. - packed_seq_params = PackedSeqParams( - qkv_format="sbhd", - cu_seqlens_q=torch.tensor([0, sample_length], device="cuda", pin_memory=True), - cu_seqlens_kv=torch.tensor([0, sample_length], device="cuda", pin_memory=True), - cu_seqlens_q_padded=torch.tensor([0, sample_length], device="cuda", pin_memory=True), - cu_seqlens_kv_padded=torch.tensor([0, sample_length], device="cuda", pin_memory=True), - max_seqlen_q=sample_length, - max_seqlen_kv=sample_length, - local_cp_size=local_cp_size, - cp_group=cp_group, - ) - - if cp_group is not None and cp_group.size() > 1: - # When using hybrid_context_parallel, each sub-sample of a packed sample is - # required to be divisible by CP*DP*2 or CP*DP*TP*2 (if using sequence parallel) - batch = get_batch_on_this_cp_rank(batch, cp_group.size(), torch.distributed.get_rank(group=cp_group)) + batch, packed_seq_params = get_batch_on_this_hybrid_cp_rank(batch, local_cp_size) return (*batch.values(), packed_seq_params) From a3b5ebefcf3895386959dcce829d3b79d1cd0218 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Wed, 15 Oct 2025 18:27:19 -0700 Subject: [PATCH 63/81] Lint --- .../pipeline_parallel/hybrid_cp_schedule.py | 1 + megatron/core/utils.py | 38 ++++++++++++------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index f2f8caffc8a..26ec6d4f3e1 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -11,6 +11,7 @@ from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.rerun_state_machine import RerunDataIterator + class HybridCPDataLoaderWrapper: """ A wrapper class that wraps around an existing data_iterator. diff --git a/megatron/core/utils.py b/megatron/core/utils.py index 8412c325207..c3cea956411 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -1856,10 +1856,14 @@ def get_batch_on_this_cp_rank( def get_thd_batch_on_this_cp_rank( - batch: Dict[str, Any], cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, cp_size: Optional[int] = None, cp_rank: Optional[int] = None + batch: Dict[str, Any], + cu_seqlens: torch.Tensor, + max_seqlen: torch.Tensor, + cp_size: Optional[int] = None, + cp_rank: Optional[int] = None, ): - """Slice each sub-sample in a packed sample batch input along - sequence dimension into multiple chunks, which are parallelized + """Slice each sub-sample in a packed sample batch input along + sequence dimension into multiple chunks, which are parallelized across GPUs in a context parallel group. """ packed_seq_params = PackedSeqParams( @@ -1880,10 +1884,7 @@ def get_thd_batch_on_this_cp_rank( "Context Parallel with THD format data" ) index = tex.thd_get_partitioned_indices( - cu_seqlens, - batch['tokens'].size(1), - cp_size, - cp_rank, + cu_seqlens, batch['tokens'].size(1), cp_size, cp_rank ) for key, data in batch.items(): if key in {'attention_mask', 'cu_seqlens', 'max_seqlen'}: @@ -1897,21 +1898,27 @@ def get_thd_batch_on_this_cp_rank( ### hybrid context parallel ### ################################ + def get_batch_on_this_hybrid_cp_rank( - batch: Dict[str, Any], local_cp_size: int, cp_group: Optional[torch.distributed.ProcessGroup] = None + batch: Dict[str, Any], + local_cp_size: int, + cp_group: Optional[torch.distributed.ProcessGroup] = None, ): """Slice batch input along sequence dimension into multiple chunks, which are parallelized across GPUs in a context parallel group. """ assert local_cp_size is not None if cp_group is None: - # Get the local cp group required for this sample as defined by the Hybrid CP dataloader wrapper + # Get the local cp group required for as defined by the HybridCPDataLoaderWrapper if local_cp_size > 1: - cp_group = parallel_state.get_hybrid_data_context_parallel_groups(group_size=local_cp_size) + cp_group = parallel_state.get_hybrid_data_context_parallel_groups( + group_size=local_cp_size + ) else: - # If cp group is provided, it must match the local cp size defined by the Hybrid CP dataloader wrapper + # If cp group is provided, it must match the local cp size + # as defined by the HybridCPDataLoaderWrapper assert cp_group.size() == local_cp_size - + # Convert [seqlen] to [1, seqlen] similar to default collate_fn # as hybrid_context_parallel dataloader wrapper does not go through default collate_fn for key, data in batch.items(): @@ -1933,12 +1940,15 @@ def get_batch_on_this_hybrid_cp_rank( ) if cp_group is not None and cp_group.size() > 1: - # When using hybrid_context_parallel, each sub-sample of a packed sample is + # When using hybrid_context_parallel, each sub-sample of a packed sample is # required to be divisible by CP*DP*2 or CP*DP*TP*2 (if using sequence parallel) - batch = get_batch_on_this_cp_rank(batch, cp_group.size(), torch.distributed.get_rank(group=cp_group)) + batch = get_batch_on_this_cp_rank( + batch, cp_group.size(), torch.distributed.get_rank(group=cp_group) + ) return batch, packed_seq_params + ###################### ### NVTX profiling ### ###################### From 1c22f5c29df98a7c7edd22fe44adbcf46d76cab5 Mon Sep 17 00:00:00 2001 From: Mcore Bot Date: Thu, 16 Oct 2025 01:38:24 +0000 Subject: [PATCH 64/81] chore: Format files --- megatron/core/datasets/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/datasets/utils.py b/megatron/core/datasets/utils.py index 656f969aff2..e14656df799 100644 --- a/megatron/core/datasets/utils.py +++ b/megatron/core/datasets/utils.py @@ -89,4 +89,4 @@ def get_blend_from_list( prefix_per_dataset = [rppd.strip() for rppd in raw_prefix_per_dataset] - return prefix_per_dataset, weight_per_dataset \ No newline at end of file + return prefix_per_dataset, weight_per_dataset From 027f46f6bbaf77dc5c1d90d6d137bde1d9548630 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Fri, 7 Nov 2025 16:18:11 -0800 Subject: [PATCH 65/81] Fix import --- megatron/core/utils.py | 8 ++++++++ pretrain_gpt.py | 10 ---------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/megatron/core/utils.py b/megatron/core/utils.py index c3cea956411..233306f8844 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -57,6 +57,14 @@ logger = logging.getLogger(__name__) +try: + # Register the TE CUDA kernels + import transformer_engine # pylint: disable=unused-import + # Alias the PyTorch wrapper so we can call tex.* APIs + import transformer_engine_torch as tex +except ImportError: + # TE isn’t installed or the torch wrapper is missing + tex = None try: _torch_version = PkgVersion(torch.__version__) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 0e7778d7f6a..25b579091d0 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -39,16 +39,6 @@ except ImportError: has_nvidia_modelopt = False -try: - # Register the TE CUDA kernels - import transformer_engine # pylint: disable=unused-import - - # Alias the PyTorch wrapper so we can call tex.* APIs - import transformer_engine_torch as tex -except ImportError: - # TE isn’t installed or the torch wrapper is missing - tex = None - stimer = StragglerDetector() From 11ed9dbb8a9f152586e8644943f8332a9a6ffc8d Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Mon, 17 Nov 2025 15:22:46 -0800 Subject: [PATCH 66/81] Move Dataloader wrapper to core/datasets --- megatron/core/datasets/data_schedule.py | 299 +++++++++++++++++ .../pipeline_parallel/hybrid_cp_schedule.py | 300 +----------------- megatron/training/training.py | 2 +- 3 files changed, 302 insertions(+), 299 deletions(-) create mode 100644 megatron/core/datasets/data_schedule.py diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py new file mode 100644 index 00000000000..93e22daaf04 --- /dev/null +++ b/megatron/core/datasets/data_schedule.py @@ -0,0 +1,299 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. + +from typing import Any, List, Optional +import torch +from megatron.core import parallel_state +from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler +from megatron.core.process_groups_config import ProcessGroupCollection + +class HybridCPDataLoaderWrapper: + """ + A wrapper class that wraps around an existing data_iterator. + For every __next__ call, + 1. Each DP rank pulls a batch of packed samples. + 2. Extracts the sequence lengths of each sub-sample and all-gathers across the DP group. + 3. Schedules the sub-samples to the DPxCP ranks using the BalancedCPScheduler. + 4. Based on the schedule, reroutes the sub-samples to the correct rank using all-to-all. + 5. Returns the assigned sub-samples to this rank. + + Args: + data_iterator: The original data_iterator to wrap around + config: The config object containing the max_seqlen_per_dp_cp_rank + dp_cp_group: Data parallel context parallel group. + """ + + def __init__( + self, data_iterator, config, pg_collection: Optional[ProcessGroupCollection] = None + ): + self.data_iterator = data_iterator + self.config = config + if pg_collection is None: + self.dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) + self.dp_group = parallel_state.get_data_parallel_group() + self.tp_group = parallel_state.get_tensor_model_parallel_group() + else: + self.dp_cp_group = pg_collection.dp_cp + self.dp_group = pg_collection.dp + self.tp_group = pg_collection.tp + assert ( + self.dp_cp_group is not None and self.dp_group is not None and self.tp_group is not None + ), "dp_cp_group, dp_group, tp_group must not be None when using hybrid context parallel" + + self.cp_balancing_scheduler = BalancedCPScheduler( + max_seq_len_per_rank=self.config.max_seqlen_per_dp_cp_rank, + dp_cp_group=self.dp_cp_group + ) + + self.total_hdp_gpus = self.dp_cp_group.size() + + def __iter__(self): + """Return self as an iterator.""" + return self + + def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: + """ + Gathers the sequence lengths of all subsamples from all DP ranks. + Each DP rank loads the same number of microbatches but each microbatch + may have a different number of subsamples. + + We find the number of subsamples each rank holds and then gather the + sequence lengths of all subsamples from all ranks. + """ + # Collect the number of subsamples from all ranks + local_len = torch.tensor([subsample_seqlens.shape[0]], dtype=torch.int32).cuda() + dp_subsample_count = [torch.zeros_like(local_len) for _ in range(self.dp_group.size())] + torch.distributed.all_gather(dp_subsample_count, local_len, group=self.dp_group) + + # Find the max number of subsamples across all ranks and pad subsample_seqlens to max length + dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1) + max_sub_samples = int(dp_subsample_counts.max().item()) + + if local_len.item() < max_sub_samples: + subsample_seqlens_padded = torch.cat( + [ + subsample_seqlens, + torch.zeros(max_sub_samples - local_len.item(), dtype=torch.int32).cuda(), + ], + dim=0, + ) + else: + subsample_seqlens_padded = subsample_seqlens + + # Gather the subsample_seqlens from all ranks + seqlens_gathered = [ + torch.empty_like(subsample_seqlens_padded) for _ in range(self.dp_group.size()) + ] + torch.distributed.all_gather( + seqlens_gathered, subsample_seqlens_padded, group=self.dp_group + ) + + # Trim each seqlens_gathered to the length of the correct sample + for dp_rank, seqlen in enumerate(seqlens_gathered): + seqlens_gathered[dp_rank] = seqlen[: dp_subsample_counts[dp_rank]] + + seqlens_gathered = torch.cat(seqlens_gathered, dim=0) + seqlens_gathered = seqlens_gathered.cpu().tolist() + + # Calculate the offsets to assign unique global ID to each subsample. + csum = torch.cumsum(dp_subsample_counts, dim=0, dtype=torch.int32) + offsets = torch.cat([torch.zeros(1, dtype=torch.int32), csum[:-1]], dim=0) + + return seqlens_gathered, offsets + + def get_global_id_seqlens(self, num_local_subsamples, offsets, seqlens_gathered): + """ + Calculates the global ID for each subsample. + + We assign a unique global ID to each subsample. + + Returns: + global_id_seqlens: list of (global_id, seqlen) tuples for scheduling. + global_ids_this_rank: list of global IDs locally present on this rank. + """ + dp_rank = self.dp_group.rank() + global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda() + # Create a list of (global_id, seqlen) tuples for scheduling + global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] + # Get the global IDs locally present on this rank + global_ids_this_rank = global_ids[ + offsets[dp_rank] : offsets[dp_rank] + num_local_subsamples + ] + + return global_id_seqlens, global_ids_this_rank + + def _gid_to_src_rank(self, gid: int, offsets: List[int]) -> int: + dp_src_rank = torch.bucketize(gid, offsets[1:] - 1) + # Since the torch.distributed.get_process_group_ranks + # provides the global rank, we need to consider TP + hdp_rank = ( + torch.distributed.get_process_group_ranks(self.dp_group)[dp_src_rank] + // self.tp_group.size() + ) + return hdp_rank + + def reroute_samples_to_hdp_ranks( + self, batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets + ): + """ + Reroutes the sub-samples to the correct rank after scheduling. + + For each key in the batch dict, we perform an all-to-all communication + to transfer the data to the correct ranks. + Since all CP ranks within a DP group have the same data, we only need + to transfer data between matching CP ranks. + """ + gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} + hdp_rank = self.dp_cp_group.rank() + dp_ranks = torch.distributed.get_process_group_ranks(self.dp_group) + # Here we actually want to get the DP group's rank within the HDP group, + # we need to consider TP + dp_ranks = [r // self.tp_group.size() for r in dp_ranks] + + data_keys = batch[0].keys() + + # Create the send plan + combined_sample_id_groups: List[List[int]] = [[] for _ in range(self.total_hdp_gpus)] + + for d in range(self.total_hdp_gpus): + for sample_id_group in sample_id_groups: + combined_sample_id_groups[d].extend(sample_id_group[d]) + + for dest_rank in range(self.total_hdp_gpus): + combined_sample_id_groups[dest_rank].sort() + + # Filter out samples that are not present on this rank + send_ids_sorted = [ + gid + for d in dp_ranks + for gid in combined_sample_id_groups[d] + if gid in global_ids_this_rank + ] + # send_counts = [len(combined_sample_id_groups[d]) for d in range(self.total_hdp_gpus)] + + send_lens_split = [0] * self.total_hdp_gpus + for dest_rank in range(self.total_hdp_gpus): + if dest_rank in dp_ranks: + send_lens_split[dest_rank] = sum( + [ + global_id_seqlens[gid][1] + for gid in combined_sample_id_groups[dest_rank] + if gid in global_ids_this_rank + ] + ) + else: + # We only need to share local data with DP ranks that have different data. + send_lens_split[dest_rank] = 0 + + # Create the recv plan + recv_sample_id_groups = [[] for _ in range(self.total_hdp_gpus)] + for gid in combined_sample_id_groups[hdp_rank]: + src_rank = self._gid_to_src_rank(gid, offsets) + recv_sample_id_groups[src_rank].append(gid) + + recv_lens_split = [0] * self.total_hdp_gpus + for src_rank in range(self.total_hdp_gpus): + recv_lens_split[src_rank] = sum( + [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]] + ) + + recv_ids_sorted = [ + gid for d in range(self.total_hdp_gpus) for gid in recv_sample_id_groups[d] + ] + recv_counts = [len(recv_sample_id_groups[d]) for d in range(self.total_hdp_gpus)] + + recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))] + + def _pack_sample_by_key(key: str) -> torch.Tensor: + flattened_tensors = [] + for gid in send_ids_sorted: + t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True) + flattened_tensors.append(t) + return ( + torch.cat(flattened_tensors, dim=0) + if flattened_tensors + else torch.empty(0, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) + ) + + def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): + cursor = 0 + for i, gid in enumerate(recv_ids_sorted): + sample_len = global_id_seqlens[gid][1] + recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len] + cursor += sample_len + + for key in data_keys: + send_tensor = _pack_sample_by_key(key) + recv_tensor = torch.empty( + sum(recv_lens_split), device=torch.cuda.current_device(), dtype=send_tensor.dtype + ) + torch.distributed.all_to_all_single( + output=recv_tensor, + input=send_tensor, + output_split_sizes=recv_lens_split, + input_split_sizes=send_lens_split, + group=self.dp_cp_group, + ) + _unpack_sample_by_key(key, recv_tensor) + + recv_sample_with_id = { + recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted) + } + return recv_sample_with_id + + def unpack_batch(self, batch): + """ + Unpacks the packed samples into a list of sub-samples. + Since each sub-sample may be routed to different DPxCP ranks, + we unpack the sample here to avoid unnecessarily transferring + the entire packed sample. + """ + batch_unpacked = [] + for sample in batch: + for sub_sample in range(sample["cu_seqlens"].shape[0] - 1): + sub_sample_dict = {} + start_idx = sample["cu_seqlens"][sub_sample] + end_idx = sample["cu_seqlens"][sub_sample + 1] + if end_idx - start_idx == 0: + continue + for key in sample.keys(): + if key in ["cu_seqlens", "batch_idx", "max_seqlen"]: + continue + sub_sample_dict[key] = sample[key][start_idx:end_idx] + batch_unpacked.append(sub_sample_dict) + return batch_unpacked + + def __next__(self) -> Any: + """ + Get the next item from the dataset, pull scheduling metadata and return it. + """ + if self.data_iterator is None: + # TP0 reads from data_iterator, others receive via broadcast. + return None, None + else: + batch = next(self.data_iterator) + subsample_seqlens = [] + for sample in batch: + subsample_seqlens.extend( + [ + int(sample["cu_seqlens"][i + 1] - sample["cu_seqlens"][i]) + for i in range(0, sample["cu_seqlens"].shape[0] - 1) + ] + ) + subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda() + subsample_seqlens = subsample_seqlens[subsample_seqlens != 0] + + seqlens_gathered, offsets = self.get_global_seqlens(subsample_seqlens) + + global_id_seqlens, global_ids_this_rank = self.get_global_id_seqlens( + subsample_seqlens.shape[0], offsets, seqlens_gathered + ) + + groups, sample_id_groups = self.cp_balancing_scheduler.get_groups_and_subsamples( + global_id_seqlens, self.config + ) + + batch = self.unpack_batch(batch) + samples_this_rank_with_id = self.reroute_samples_to_hdp_ranks( + batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets + ) + return samples_this_rank_with_id, sample_id_groups \ No newline at end of file diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index 26ec6d4f3e1..c7ca94c5aa9 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -8,315 +8,21 @@ import torch from megatron.core import parallel_state -from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.rerun_state_machine import RerunDataIterator -class HybridCPDataLoaderWrapper: - """ - A wrapper class that wraps around an existing data_iterator. - For every __next__ call, - 1. Each DP rank pulls a batch of packed samples. - 2. Extracts the sequence lengths of each sub-sample and all-gathers across the DP group. - 3. Schedules the sub-samples to the DPxCP ranks using the BalancedCPScheduler. - 4. Based on the schedule, reroutes the sub-samples to the correct rank using all-to-all. - 5. Returns the assigned sub-samples to this rank. - - Args: - data_iterator: The original data_iterator to wrap around - config: The config object containing the max_seqlen_per_dp_cp_rank - dp_cp_group: Data parallel context parallel group. - """ - - def __init__( - self, data_iterator, config, pg_collection: Optional[ProcessGroupCollection] = None - ): - self.data_iterator = data_iterator - self.config = config - self.cp_balancing_scheduler = BalancedCPScheduler( - max_seq_len_per_rank=self.config.max_seqlen_per_dp_cp_rank - ) - if pg_collection is None: - self.dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) - self.dp_group = parallel_state.get_data_parallel_group() - self.tp_group = parallel_state.get_tensor_model_parallel_group() - else: - self.dp_cp_group = pg_collection.dp_cp - self.dp_group = pg_collection.dp - self.tp_group = pg_collection.tp - assert ( - self.dp_cp_group is not None and self.dp_group is not None and self.tp_group is not None - ), "dp_cp_group, dp_group, tp_group must not be None when using hybrid context parallel" - - self.total_hdp_gpus = self.dp_cp_group.size() - - def __iter__(self): - """Return self as an iterator.""" - return self - - def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: - """ - Gathers the sequence lengths of all subsamples from all DP ranks. - Each DP rank loads the same number of microbatches but each microbatch - may have a different number of subsamples. - - We find the number of subsamples each rank holds and then gather the - sequence lengths of all subsamples from all ranks. - """ - # Collect the number of subsamples from all ranks - local_len = torch.tensor([subsample_seqlens.shape[0]], dtype=torch.int32).cuda() - dp_subsample_count = [torch.zeros_like(local_len) for _ in range(self.dp_group.size())] - torch.distributed.all_gather(dp_subsample_count, local_len, group=self.dp_group) - - # Find the max number of subsamples across all ranks and pad subsample_seqlens to max length - dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1) - max_sub_samples = int(dp_subsample_counts.max().item()) - - if local_len.item() < max_sub_samples: - subsample_seqlens_padded = torch.cat( - [ - subsample_seqlens, - torch.zeros(max_sub_samples - local_len.item(), dtype=torch.int32).cuda(), - ], - dim=0, - ) - else: - subsample_seqlens_padded = subsample_seqlens - - # Gather the subsample_seqlens from all ranks - seqlens_gathered = [ - torch.empty_like(subsample_seqlens_padded) for _ in range(self.dp_group.size()) - ] - torch.distributed.all_gather( - seqlens_gathered, subsample_seqlens_padded, group=self.dp_group - ) - - # Trim each seqlens_gathered to the length of the correct sample - for dp_rank, seqlen in enumerate(seqlens_gathered): - seqlens_gathered[dp_rank] = seqlen[: dp_subsample_counts[dp_rank]] - - seqlens_gathered = torch.cat(seqlens_gathered, dim=0) - seqlens_gathered = seqlens_gathered.cpu().tolist() - - # Calculate the offsets to assign unique global ID to each subsample. - csum = torch.cumsum(dp_subsample_counts, dim=0, dtype=torch.int32) - offsets = torch.cat([torch.zeros(1, dtype=torch.int32), csum[:-1]], dim=0) - - return seqlens_gathered, offsets - - def get_global_id_seqlens(self, num_local_subsamples, offsets, seqlens_gathered): - """ - Calculates the global ID for each subsample. - - We assign a unique global ID to each subsample. - - Returns: - global_id_seqlens: list of (global_id, seqlen) tuples for scheduling. - global_ids_this_rank: list of global IDs locally present on this rank. - """ - dp_rank = self.dp_group.rank() - global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda() - # Create a list of (global_id, seqlen) tuples for scheduling - global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] - # Get the global IDs locally present on this rank - global_ids_this_rank = global_ids[ - offsets[dp_rank] : offsets[dp_rank] + num_local_subsamples - ] - - return global_id_seqlens, global_ids_this_rank - - def _gid_to_src_rank(self, gid: int, offsets: List[int]) -> int: - dp_src_rank = torch.bucketize(gid, offsets[1:] - 1) - # Since the torch.distributed.get_process_group_ranks - # provides the global rank, we need to consider TP - hdp_rank = ( - torch.distributed.get_process_group_ranks(self.dp_group)[dp_src_rank] - // self.tp_group.size() - ) - return hdp_rank - - def reroute_samples_to_hdp_ranks( - self, batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets - ): - """ - Reroutes the sub-samples to the correct rank after scheduling. - - For each key in the batch dict, we perform an all-to-all communication - to transfer the data to the correct ranks. - Since all CP ranks within a DP group have the same data, we only need - to transfer data between matching CP ranks. - """ - gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} - hdp_rank = self.dp_cp_group.rank() - dp_ranks = torch.distributed.get_process_group_ranks(self.dp_group) - # Here we actually want to get the DP group's rank within the HDP group, - # we need to consider TP - dp_ranks = [r // self.tp_group.size() for r in dp_ranks] - - data_keys = batch[0].keys() - - # Create the send plan - combined_sample_id_groups: List[List[int]] = [[] for _ in range(self.total_hdp_gpus)] - - for d in range(self.total_hdp_gpus): - for sample_id_group in sample_id_groups: - combined_sample_id_groups[d].extend(sample_id_group[d]) - - for dest_rank in range(self.total_hdp_gpus): - combined_sample_id_groups[dest_rank].sort() - - # Filter out samples that are not present on this rank - send_ids_sorted = [ - gid - for d in dp_ranks - for gid in combined_sample_id_groups[d] - if gid in global_ids_this_rank - ] - # send_counts = [len(combined_sample_id_groups[d]) for d in range(self.total_hdp_gpus)] - - send_lens_split = [0] * self.total_hdp_gpus - for dest_rank in range(self.total_hdp_gpus): - if dest_rank in dp_ranks: - send_lens_split[dest_rank] = sum( - [ - global_id_seqlens[gid][1] - for gid in combined_sample_id_groups[dest_rank] - if gid in global_ids_this_rank - ] - ) - else: - # We only need to share local data with DP ranks that have different data. - send_lens_split[dest_rank] = 0 - - # Create the recv plan - recv_sample_id_groups = [[] for _ in range(self.total_hdp_gpus)] - for gid in combined_sample_id_groups[hdp_rank]: - src_rank = self._gid_to_src_rank(gid, offsets) - recv_sample_id_groups[src_rank].append(gid) - - recv_lens_split = [0] * self.total_hdp_gpus - for src_rank in range(self.total_hdp_gpus): - recv_lens_split[src_rank] = sum( - [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]] - ) - - recv_ids_sorted = [ - gid for d in range(self.total_hdp_gpus) for gid in recv_sample_id_groups[d] - ] - recv_counts = [len(recv_sample_id_groups[d]) for d in range(self.total_hdp_gpus)] - - recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))] - - def _pack_sample_by_key(key: str) -> torch.Tensor: - flattened_tensors = [] - for gid in send_ids_sorted: - t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True) - flattened_tensors.append(t) - return ( - torch.cat(flattened_tensors, dim=0) - if flattened_tensors - else torch.empty(0, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) - ) - - def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): - cursor = 0 - for i, gid in enumerate(recv_ids_sorted): - sample_len = global_id_seqlens[gid][1] - recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len] - cursor += sample_len - - for key in data_keys: - send_tensor = _pack_sample_by_key(key) - recv_tensor = torch.empty( - sum(recv_lens_split), device=torch.cuda.current_device(), dtype=send_tensor.dtype - ) - torch.distributed.all_to_all_single( - output=recv_tensor, - input=send_tensor, - output_split_sizes=recv_lens_split, - input_split_sizes=send_lens_split, - group=self.dp_cp_group, - ) - _unpack_sample_by_key(key, recv_tensor) - - recv_sample_with_id = { - recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted) - } - return recv_sample_with_id - - def unpack_batch(self, batch): - """ - Unpacks the packed samples into a list of sub-samples. - Since each sub-sample may be routed to different DPxCP ranks, - we unpack the sample here to avoid unnecessarily transferring - the entire packed sample. - """ - batch_unpacked = [] - for sample in batch: - for sub_sample in range(sample["cu_seqlens"].shape[0] - 1): - sub_sample_dict = {} - start_idx = sample["cu_seqlens"][sub_sample] - end_idx = sample["cu_seqlens"][sub_sample + 1] - if end_idx - start_idx == 0: - continue - for key in sample.keys(): - if key in ["cu_seqlens", "batch_idx", "max_seqlen"]: - continue - sub_sample_dict[key] = sample[key][start_idx:end_idx] - batch_unpacked.append(sub_sample_dict) - return batch_unpacked - - def __next__(self) -> Any: - """ - Get the next item from the dataset, pull scheduling metadata and return it. - """ - if self.data_iterator is None: - # TP0 reads from data_iterator, others receive via broadcast. - return None, None - else: - batch = next(self.data_iterator) - subsample_seqlens = [] - for sample in batch: - subsample_seqlens.extend( - [ - int(sample["cu_seqlens"][i + 1] - sample["cu_seqlens"][i]) - for i in range(0, sample["cu_seqlens"].shape[0] - 1) - ] - ) - subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda() - subsample_seqlens = subsample_seqlens[subsample_seqlens != 0] - - seqlens_gathered, offsets = self.get_global_seqlens(subsample_seqlens) - - global_id_seqlens, global_ids_this_rank = self.get_global_id_seqlens( - subsample_seqlens.shape[0], offsets, seqlens_gathered - ) - - groups, sample_id_groups = self.cp_balancing_scheduler.get_groups_and_subsamples( - global_id_seqlens, self.config - ) - - batch = self.unpack_batch(batch) - samples_this_rank_with_id = self.reroute_samples_to_hdp_ranks( - batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets - ) - return samples_this_rank_with_id, sample_id_groups - - class BalancedCPScheduler: """ This class provides the functionality to form groups of sub-samples such that all DPxCP ranks have a roughly balanced workload in the group. """ - def __init__(self, max_seq_len_per_rank: int): + def __init__(self, max_seq_len_per_rank: int, dp_cp_group: torch.distributed.ProcessGroup): self.max_seq_len_per_rank = max_seq_len_per_rank self.num_subsamples = 0 self.num_subsamples_processed = 0 self.free_resources = [] - self.total_hdp_gpus = parallel_state.get_data_parallel_world_size( - with_context_parallel=True - ) + self.total_hdp_gpus = dp_cp_group.size() @lru_cache(maxsize=128) def get_total_workload(self, seq_length: int, cp_size: Optional[int] = None): @@ -423,8 +129,6 @@ def next_hdp_group( we redistribute work to empty GPUs by recursively increasing the CP size of a sample until no empty GPUs are left. - #TODO: Add clarification on when we check for balance. What does prev_needed do? - Returns (micro_batches, leftover_sample_seqlens, exec_times, sample_ids_per_gpu). """ if not sample_seqlens: diff --git a/megatron/training/training.py b/megatron/training/training.py index cf9226de4c7..747906eba9e 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -86,13 +86,13 @@ from megatron.training.initialize import set_jit_fusion_options from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank from megatron.legacy.data.data_samplers import build_pretraining_data_loader +from megatron.core.datasets.data_schedule import HybridCPDataLoaderWrapper from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler from megatron.core.transformer.moe import upcycling_utils from megatron.core.transformer.moe.moe_utils import track_moe_metrics from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper from megatron.core.parallel_state import destroy_global_memory_buffer, destroy_model_parallel from megatron.core.pipeline_parallel import get_forward_backward_func -from megatron.core.pipeline_parallel.hybrid_cp_schedule import HybridCPDataLoaderWrapper from megatron.core.num_microbatches_calculator import ( destroy_num_microbatches_calculator, get_current_global_batch_size, From c59e199b52798b2d097df0209cd1ba3dbfff8c0a Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Mon, 17 Nov 2025 15:23:34 -0800 Subject: [PATCH 67/81] Formatting --- megatron/core/datasets/data_schedule.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index 93e22daaf04..b3f9df64b4b 100644 --- a/megatron/core/datasets/data_schedule.py +++ b/megatron/core/datasets/data_schedule.py @@ -6,6 +6,7 @@ from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler from megatron.core.process_groups_config import ProcessGroupCollection + class HybridCPDataLoaderWrapper: """ A wrapper class that wraps around an existing data_iterator. @@ -38,10 +39,9 @@ def __init__( assert ( self.dp_cp_group is not None and self.dp_group is not None and self.tp_group is not None ), "dp_cp_group, dp_group, tp_group must not be None when using hybrid context parallel" - + self.cp_balancing_scheduler = BalancedCPScheduler( - max_seq_len_per_rank=self.config.max_seqlen_per_dp_cp_rank, - dp_cp_group=self.dp_cp_group + max_seq_len_per_rank=self.config.max_seqlen_per_dp_cp_rank, dp_cp_group=self.dp_cp_group ) self.total_hdp_gpus = self.dp_cp_group.size() @@ -296,4 +296,4 @@ def __next__(self) -> Any: samples_this_rank_with_id = self.reroute_samples_to_hdp_ranks( batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets ) - return samples_this_rank_with_id, sample_id_groups \ No newline at end of file + return samples_this_rank_with_id, sample_id_groups From c26dd8e538580e579dd2572b935584087a5df210 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Mon, 24 Nov 2025 17:07:20 -0800 Subject: [PATCH 68/81] Fixes and address comments Signed-off-by: Parth Mannan --- .../models/common/embeddings/rotary_pos_embedding.py | 8 ++------ .../common/embeddings/yarn_rotary_pos_embedding.py | 8 ++------ megatron/core/utils.py | 10 ++++++---- pretrain_gpt.py | 12 ++++++------ 4 files changed, 16 insertions(+), 22 deletions(-) diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index a8f94f46886..c7c452d2f87 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -191,12 +191,8 @@ def forward( emb = self.get_emb(max_seq_len, offset) packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' if packed_seq_params is not None and packed_seq_params.local_cp_size is not None: - if packed_seq_params.local_cp_size > 1: - # Set CP group to dynamic CP group for CP slicing - cp_group = packed_seq_params.cp_group - else: - # Set CP group to None to avoid CP slicing - cp_group = None + # Set CP group to dynamic CP group for CP slicing + cp_group = packed_seq_params.cp_group else: cp_group = self.cp_group diff --git a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py index f3083cbb087..2c75b5ca84f 100644 --- a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py @@ -168,12 +168,8 @@ def forward( emb, _mscale = self.get_emb(max_seq_len, offset) packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' if packed_seq_params is not None and packed_seq_params.local_cp_size is not None: - if packed_seq_params.local_cp_size > 1: - # Set CP group to dynamic CP group for CP slicing - cp_group = packed_seq_params.cp_group - else: - # Set CP group to None to avoid CP slicing - cp_group = None + # Set CP group to dynamic CP group for CP slicing + cp_group = packed_seq_params.cp_group else: cp_group = self.cp_group if cp_group is not None and cp_group.size() > 1 and not packed_seq: diff --git a/megatron/core/utils.py b/megatron/core/utils.py index d901011ac1b..22138659617 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -1943,6 +1943,7 @@ def get_batch_on_this_cp_rank( def get_thd_batch_on_this_cp_rank( batch: Dict[str, Any], cu_seqlens: torch.Tensor, + cu_seqlens_padded: torch.Tensor, max_seqlen: torch.Tensor, cp_size: Optional[int] = None, cp_rank: Optional[int] = None, @@ -1955,8 +1956,8 @@ def get_thd_batch_on_this_cp_rank( qkv_format="thd", cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens, - cu_seqlens_q_padded=cu_seqlens, - cu_seqlens_kv_padded=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, max_seqlen_q=int(max_seqlen[0].item()), max_seqlen_kv=int(max_seqlen[0].item()), ) @@ -1969,10 +1970,10 @@ def get_thd_batch_on_this_cp_rank( "Context Parallel with THD format data" ) index = tex.thd_get_partitioned_indices( - cu_seqlens, batch['tokens'].size(1), cp_size, cp_rank + cu_seqlens_padded, batch['tokens'].size(1), cp_size, cp_rank ) for key, data in batch.items(): - if key in {'attention_mask', 'cu_seqlens', 'max_seqlen'}: + if key in {'attention_mask', 'cu_seqlens', 'cu_seqlens_padded', 'max_seqlen'}: continue batch[key] = data.index_select(1, index) @@ -2011,6 +2012,7 @@ def get_batch_on_this_hybrid_cp_rank( continue batch[key] = torch.stack([data], 0) sample_length = batch['tokens'].shape[1] + # TODO(pmannan): Take care of padding tokens here if not divisible by cp_size*2 # Create packed_seq_params for SBHD format with cp group information. packed_seq_params = PackedSeqParams( qkv_format="sbhd", diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 7f69b0e0814..5b46c80da0f 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -41,14 +41,15 @@ def get_batch(data_iterator, vp_stage=None): """Generate a batch.""" # TODO: this is pretty hacky, find a better way if not is_first_or_last_pipeline_stage(vp_stage): - return None, None, None, None, None + return None, None, None, None, None, None # get batches based on the TP rank you are on batch = get_batch_on_this_tp_rank(data_iterator) - cu_seqlens = batch.pop('cu_seqlens') - max_seqlen = batch.pop('max_seqlen') - local_cp_size = batch.pop('local_cp_size') + cu_seqlens = batch.pop('cu_seqlens', None) + cu_seqlens_padded = batch.pop('cu_seqlens_padded', None) + max_seqlen = batch.pop('max_seqlen', None) + local_cp_size = batch.pop('local_cp_size', None) if local_cp_size is not None: local_cp_size = int(local_cp_size.item()) @@ -57,9 +58,8 @@ def get_batch(data_iterator, vp_stage=None): batch = get_batch_on_this_cp_rank(batch) # The implementation of this function is in MCore packed_seq_params = None elif local_cp_size is None: # Packed THD format - cu_seqlens = cu_seqlens[0] assert max_seqlen.dim() == 1 - batch, packed_seq_params = get_thd_batch_on_this_cp_rank(batch, cu_seqlens, max_seqlen) + batch, packed_seq_params = get_thd_batch_on_this_cp_rank(batch, cu_seqlens, cu_seqlens_padded, max_seqlen) else: # Hybrid CP format batch, packed_seq_params = get_batch_on_this_hybrid_cp_rank(batch, local_cp_size) From 7348e73c420e225af15c6ab52fa831f0f42bdc13 Mon Sep 17 00:00:00 2001 From: Kunlun Li <94586211+kunlunl@users.noreply.github.com> Date: Mon, 1 Dec 2025 20:40:14 +0800 Subject: [PATCH 69/81] Fix TP hanging issue --- megatron/training/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 6a9cec07ae4..dd6c8cc46d7 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -635,7 +635,7 @@ def _broadcast_cu_seqlens(cu_seqlens): 1, dtype=torch.int32, device=torch.cuda.current_device(), - ) if not args.hybrid_context_parallel else None + ) if args.hybrid_context_parallel else None local_cp_size = torch.empty( 1, dtype=torch.int32, From b286b8a86f39037c5e68c13f9770799356703632 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Wed, 3 Dec 2025 13:22:57 -0800 Subject: [PATCH 70/81] Mamba script fix --- pretrain_mamba.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pretrain_mamba.py b/pretrain_mamba.py index 45b646a6cc0..ca2008620be 100644 --- a/pretrain_mamba.py +++ b/pretrain_mamba.py @@ -44,6 +44,13 @@ def get_batch(data_iterator, vp_stage=None): # get batches based on the TP rank you are on batch = get_batch_on_this_tp_rank(data_iterator) + + # Support for Packed Sequence (Unused in this script) + cu_seqlens = batch.pop('cu_seqlens', None) + cu_seqlens_padded = batch.pop('cu_seqlens_padded', None) + max_seqlen = batch.pop('max_seqlen', None) + # Support for Hybrid Context Parallel (Unused in this script) + local_cp_size = batch.pop('local_cp_size', None) # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) From 809562161aae50fef2c9888363780cb558cb11c7 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Sun, 11 Jan 2026 19:43:03 -0800 Subject: [PATCH 71/81] Lint fix --- megatron/core/extensions/transformer_engine.py | 2 +- megatron/core/utils.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 98f04f95aaa..480ad63cd49 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -1355,7 +1355,7 @@ def forward( super().set_context_parallel_group(None, None, None, self.cp_comm_type) self.kept_packed_seq_params.discard("cp_group") self.kept_packed_seq_params.discard("local_cp_size") - + # Default to constructor-provided num_splits unless explicitly overridden if num_splits is None: num_splits = self.num_splits diff --git a/megatron/core/utils.py b/megatron/core/utils.py index f5ae644f34d..5b4ab5cc978 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -76,6 +76,7 @@ try: # Register the TE CUDA kernels import transformer_engine # pylint: disable=unused-import + # Alias the PyTorch wrapper so we can call tex.* APIs import transformer_engine_torch as tex except ImportError: From a1e2df656120345945db843bd46235f227bf9e40 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 15 Jan 2026 13:08:58 -0800 Subject: [PATCH 72/81] Address RotaryEmbedding comments --- .../common/embeddings/rotary_pos_embedding.py | 27 +++++++------------ .../embeddings/yarn_rotary_pos_embedding.py | 15 ++++++----- megatron/core/models/gpt/gpt_model.py | 14 +++++++--- .../pipeline_parallel/hybrid_cp_schedule.py | 2 +- 4 files changed, 30 insertions(+), 28 deletions(-) diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index c7c452d2f87..d7908e4b185 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -147,7 +147,6 @@ def get_cos_sin(self, max_seq_len: int, offset: int = 0) -> (Tensor, Tensor): sin = torch.sin(freqs) return cos, sin - @lru_cache(maxsize=32) def get_emb(self, max_seq_len: int, offset: int = 0) -> Tensor: """Forward pass of RoPE embedding before CP sharding. @@ -175,8 +174,14 @@ def get_emb(self, max_seq_len: int, offset: int = 0) -> Tensor: emb = emb[:, None, None, :] return emb + @lru_cache(maxsize=32) + @internal_api def forward( - self, max_seq_len: int, offset: int = 0, packed_seq_params: Optional[PackedSeqParams] = None + self, + max_seq_len: int, + offset: int = 0, + packed_seq: bool = False, + cp_group: Optional[torch.distributed.ProcessGroup] = None, ) -> Tensor: """Forward pass of RoPE embedding. @@ -189,13 +194,8 @@ def forward( Tensor: Embeddings after applying RoPE. """ emb = self.get_emb(max_seq_len, offset) - packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' - if packed_seq_params is not None and packed_seq_params.local_cp_size is not None: - # Set CP group to dynamic CP group for CP slicing - cp_group = packed_seq_params.cp_group - else: + if cp_group is None: cp_group = self.cp_group - if cp_group is not None and cp_group.size() > 1 and not packed_seq: # slice rotary_pos_emb along sequence dimension # and select the parition of the current CP rank @@ -306,7 +306,7 @@ def forward( self, position_ids: torch.Tensor, mrope_section: List[int], - packed_seq_params: Optional[PackedSeqParams] = None, + cp_group: Optional[torch.distributed.ProcessGroup] = None, ) -> Tensor: """Forward pass of multimodal RoPE embedding. @@ -347,14 +347,7 @@ def forward( # shape (seq_length, bs, 1, 2 * dim) emb = emb[..., None, :].transpose(0, 1).contiguous() - if packed_seq_params is not None and packed_seq_params.local_cp_size is not None: - if packed_seq_params.local_cp_size > 1: - # Set CP group to dynamic CP group for CP slicing - cp_group = packed_seq_params.cp_group - else: - # Set CP group to None to avoid CP slicing - cp_group = None - else: + if cp_group is None: cp_group = self.cp_group if cp_group is not None and cp_group.size() > 1: # slice rotary_pos_emb along sequence dimension and select the parition of the current diff --git a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py index b1c0f59848b..69031a4584c 100644 --- a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py @@ -102,7 +102,6 @@ def __init__( # method causes a memory leak in NeMo-RL. self.forward.cache_clear() - @lru_cache(maxsize=32) def get_emb(self, max_seq_len: int, offset: int = 0) -> Tensor: """Forward pass of Yarn Rotary Embedding. @@ -156,8 +155,14 @@ def get_emb(self, max_seq_len: int, offset: int = 0) -> Tensor: emb = emb[:, None, None, :] return emb, _mscale + @lru_cache(maxsize=32) + @internal_api def forward( - self, max_seq_len: int, offset: int = 0, packed_seq_params: Optional[PackedSeqParams] = None + self, + max_seq_len: int, + offset: int = 0, + packed_seq: bool = False, + cp_group: Optional[torch.distributed.ProcessGroup] = None, ) -> Tensor: """Forward pass of Yarn Rotary Embedding. @@ -170,11 +175,7 @@ def forward( Tensor: Embeddings after applying Yarn RoPE. """ emb, _mscale = self.get_emb(max_seq_len, offset) - packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' - if packed_seq_params is not None and packed_seq_params.local_cp_size is not None: - # Set CP group to dynamic CP group for CP slicing - cp_group = packed_seq_params.cp_group - else: + if cp_group is None: cp_group = self.cp_group if cp_group is not None and cp_group.size() > 1 and not packed_seq: # slice rotary_pos_emb along sequence dimension diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 8fb11ca0db1..0932f93bf64 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -345,7 +345,10 @@ def _preprocess( inference_context, self.decoder, decoder_input, self.config, packed_seq_params ) rotary_pos_emb = self.rotary_pos_emb( - rotary_seq_len, packed_seq_params=packed_seq_params + rotary_seq_len, + packed_seq=packed_seq_params is not None + and packed_seq_params.qkv_format == 'thd', + cp_group=packed_seq_params.cp_group if packed_seq_params is not None else None, ) elif self.position_embedding_type == 'yarn': if self.training or not self.config.flash_decode: @@ -353,7 +356,10 @@ def _preprocess( inference_context, self.decoder, decoder_input, self.config, packed_seq_params ) rotary_pos_emb, _ = self.rotary_pos_emb( - rotary_seq_len, packed_seq_params=packed_seq_params + rotary_seq_len, + packed_seq=packed_seq_params is not None + and packed_seq_params.qkv_format == 'thd', + cp_group=packed_seq_params.cp_group if packed_seq_params is not None else None, ) else: raise NotImplementedError( @@ -363,7 +369,9 @@ def _preprocess( elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention: if self.training or not self.config.flash_decode: rotary_pos_emb = self.rotary_pos_emb( - position_ids, self.mrope_section, packed_seq_params=packed_seq_params + position_ids, + self.mrope_section, + cp_group=packed_seq_params.cp_group if packed_seq_params is not None else None, ) else: # Flash decoding uses precomputed cos and sin for RoPE diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py index c7ca94c5aa9..27b5fc87945 100644 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -3,7 +3,7 @@ from collections import deque from functools import lru_cache from math import ceil, log2 -from typing import Any, Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple import torch From 2e473d995908d1afab07bf4371f85031c2be9e7f Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 15 Jan 2026 13:30:40 -0800 Subject: [PATCH 73/81] Update function signature to reflect change --- .../core/models/common/embeddings/rotary_pos_embedding.py | 5 +++-- .../models/common/embeddings/yarn_rotary_pos_embedding.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index d7908e4b185..3c2648774fe 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -188,7 +188,8 @@ def forward( Args: max_seq_len (int): Maximum size of sequence offset (int, optional): RoPE offset. Defaults to 0. - packed_seq_params (PackedSeqParams, optional): Packed sequence params. Defaults to None. + packed_seq (bool, optional): Whether to use packed sequence. Defaults to False. + cp_group (torch.distributed.ProcessGroup, optional): Context parallel group. Defaults to None. Returns: Tensor: Embeddings after applying RoPE. @@ -314,7 +315,7 @@ def forward( position_ids (torch.Tensor): A postion_id tensor with shape [3, batchsize, seqlens] mrope_section (list[int]): Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - packed_seq_params (PackedSeqParams, optional): Packed sequence params. Defaults to None. + cp_group (torch.distributed.ProcessGroup, optional): Context parallel group. Defaults to None. Returns: Tensor: Embeddings after applying RoPE. diff --git a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py index 69031a4584c..add3b07c8c2 100644 --- a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py @@ -169,7 +169,8 @@ def forward( Args: max_seq_len (int): Maximum size of sequence offset (int, optional): RoPE offset. Defaults to 0. - packed_seq_params (PackedSeqParams, optional): Packed sequence params. Defaults to None. + packed_seq (bool, optional): Whether to use packed sequence. Defaults to False. + cp_group (torch.distributed.ProcessGroup, optional): Context parallel group. Defaults to None. Returns: Tensor: Embeddings after applying Yarn RoPE. From eed957c6fb7d810b1add084e3c4362958d1d99db Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 15 Jan 2026 13:32:59 -0800 Subject: [PATCH 74/81] Formatting --- .../core/models/common/embeddings/rotary_pos_embedding.py | 6 ++++-- .../models/common/embeddings/yarn_rotary_pos_embedding.py | 3 ++- megatron/core/utils.py | 4 +--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index 3c2648774fe..87e801add63 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -189,7 +189,8 @@ def forward( max_seq_len (int): Maximum size of sequence offset (int, optional): RoPE offset. Defaults to 0. packed_seq (bool, optional): Whether to use packed sequence. Defaults to False. - cp_group (torch.distributed.ProcessGroup, optional): Context parallel group. Defaults to None. + cp_group (torch.distributed.ProcessGroup, optional): Context parallel group. + Defaults to None. Returns: Tensor: Embeddings after applying RoPE. @@ -315,7 +316,8 @@ def forward( position_ids (torch.Tensor): A postion_id tensor with shape [3, batchsize, seqlens] mrope_section (list[int]): Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - cp_group (torch.distributed.ProcessGroup, optional): Context parallel group. Defaults to None. + cp_group (torch.distributed.ProcessGroup, optional): Context parallel group. + Defaults to None. Returns: Tensor: Embeddings after applying RoPE. diff --git a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py index add3b07c8c2..3e29cdfe74b 100644 --- a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py @@ -170,7 +170,8 @@ def forward( max_seq_len (int): Maximum size of sequence offset (int, optional): RoPE offset. Defaults to 0. packed_seq (bool, optional): Whether to use packed sequence. Defaults to False. - cp_group (torch.distributed.ProcessGroup, optional): Context parallel group. Defaults to None. + cp_group (torch.distributed.ProcessGroup, optional): Context parallel group. + Defaults to None. Returns: Tensor: Embeddings after applying Yarn RoPE. diff --git a/megatron/core/utils.py b/megatron/core/utils.py index c7f7f3d52b3..4f581819bea 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -2171,9 +2171,7 @@ def get_batch_on_this_hybrid_cp_rank( if cp_group is not None and cp_group.size() > 1: # When using hybrid_context_parallel, each sub-sample of a packed sample is # required to be divisible by CP*DP*2 or CP*DP*TP*2 (if using sequence parallel) - batch = get_batch_on_this_cp_rank( - batch, cp_group=cp_group - ) + batch = get_batch_on_this_cp_rank(batch, cp_group=cp_group) return batch, packed_seq_params From 396ee10f9a999c021814536be1bda22c6d2ab7a3 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 15 Jan 2026 13:38:19 -0800 Subject: [PATCH 75/81] Lint fix --- megatron/core/datasets/data_schedule.py | 2 ++ .../core/models/common/embeddings/rotary_pos_embedding.py | 4 ++-- .../models/common/embeddings/yarn_rotary_pos_embedding.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index b3f9df64b4b..0f016473b6a 100644 --- a/megatron/core/datasets/data_schedule.py +++ b/megatron/core/datasets/data_schedule.py @@ -1,7 +1,9 @@ # Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. from typing import Any, List, Optional + import torch + from megatron.core import parallel_state from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler from megatron.core.process_groups_config import ProcessGroupCollection diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index 87e801add63..82deccbfcf3 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -189,7 +189,7 @@ def forward( max_seq_len (int): Maximum size of sequence offset (int, optional): RoPE offset. Defaults to 0. packed_seq (bool, optional): Whether to use packed sequence. Defaults to False. - cp_group (torch.distributed.ProcessGroup, optional): Context parallel group. + cp_group (torch.distributed.ProcessGroup, optional): Context parallel group. Defaults to None. Returns: @@ -316,7 +316,7 @@ def forward( position_ids (torch.Tensor): A postion_id tensor with shape [3, batchsize, seqlens] mrope_section (list[int]): Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - cp_group (torch.distributed.ProcessGroup, optional): Context parallel group. + cp_group (torch.distributed.ProcessGroup, optional): Context parallel group. Defaults to None. Returns: diff --git a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py index 3e29cdfe74b..42b146af6d4 100644 --- a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py @@ -170,7 +170,7 @@ def forward( max_seq_len (int): Maximum size of sequence offset (int, optional): RoPE offset. Defaults to 0. packed_seq (bool, optional): Whether to use packed sequence. Defaults to False. - cp_group (torch.distributed.ProcessGroup, optional): Context parallel group. + cp_group (torch.distributed.ProcessGroup, optional): Context parallel group. Defaults to None. Returns: From ea5d3f2e8319465c05d46060d7422137b466c059 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 15 Jan 2026 15:34:32 -0800 Subject: [PATCH 76/81] Missing import --- megatron/core/models/common/embeddings/rotary_pos_embedding.py | 2 +- .../core/models/common/embeddings/yarn_rotary_pos_embedding.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index 82deccbfcf3..05a7e8f60bb 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -25,7 +25,7 @@ apply_rotary_pos_emb, get_pos_emb_on_this_cp_rank, ) -from megatron.core.utils import deprecate_inference_params +from megatron.core.utils import deprecate_inference_params, internal_api logger = logging.getLogger(__name__) diff --git a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py index 42b146af6d4..59a9dbf5e8e 100644 --- a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py @@ -13,6 +13,7 @@ from megatron.core.models.common.embeddings.rope_utils import get_pos_emb_on_this_cp_rank from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.transformer import TransformerConfig +from megatron.core.utils import internal_api logger = logging.getLogger(__name__) From 465be7ca9f3a53ee359ecef6810143c341a33295 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 15 Jan 2026 15:57:46 -0800 Subject: [PATCH 77/81] Fix import --- megatron/training/training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/training/training.py b/megatron/training/training.py index 734a351bba0..bb2dcc99d51 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -103,7 +103,7 @@ from megatron.training.initialize import write_args_to_tensorboard from megatron.training.initialize import set_jit_fusion_options from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank -from megatron.legacy.data.data_samplers import build_pretraining_data_loader +from megatron.training.datasets.data_samplers import build_pretraining_data_loader from megatron.core.datasets.data_schedule import HybridCPDataLoaderWrapper from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler from megatron.core.transformer.moe import upcycling_utils From 68f7a9fb324a965c9274ce90bf738742b9574dda Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 15 Jan 2026 16:25:47 -0800 Subject: [PATCH 78/81] Fix MLA calls --- .../common/embeddings/yarn_rotary_pos_embedding.py | 13 +++++-------- megatron/core/transformer/multi_latent_attention.py | 10 +++------- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py index 59a9dbf5e8e..166ef9b41e7 100644 --- a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py @@ -186,15 +186,13 @@ def forward( emb = get_pos_emb_on_this_cp_rank(emb, 0, cp_group) return emb, _mscale - def _set_cos_sin_cache(self, seq_len, offset, dtype, packed_seq_params=None): + def _set_cos_sin_cache(self, seq_len, offset, dtype, packed_seq=False): self.max_seq_len_cached = seq_len self.offset_cached = offset self.dtype_cached = dtype - self.packed_seq_cached = ( - packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' - ) + self.packed_seq_cached = packed_seq - emb, _mscale = self.forward(seq_len, offset, packed_seq_params) + emb, _mscale = self.forward(seq_len, offset, packed_seq) self.register_buffer( "cos_cached", (emb.cos() * _mscale).to(dtype).contiguous(), persistent=False ) @@ -203,17 +201,16 @@ def _set_cos_sin_cache(self, seq_len, offset, dtype, packed_seq_params=None): ) def get_cached_cos_sin( - self, seq_len, offset=0, dtype=torch.get_default_dtype(), packed_seq_params=None + self, seq_len, offset=0, dtype=torch.get_default_dtype(), packed_seq=False ): """Get cached cos and sin values.""" - packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' if ( seq_len > self.max_seq_len_cached or offset != self.offset_cached or dtype != self.dtype_cached or packed_seq != self.packed_seq_cached ): - self._set_cos_sin_cache(seq_len, offset, dtype, packed_seq_params) + self._set_cos_sin_cache(seq_len, offset, dtype, packed_seq) return (self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]) diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index a6beb556956..59e6ebf6719 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -541,13 +541,11 @@ def get_query_key_value_tensors( rotary_pos_sin = None packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' if self.config.rope_type == "rope": - rotary_pos_emb = self.rotary_pos_emb( - rotary_seq_len, packed_seq_params=packed_seq_params - ) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) else: if self.config.apply_rope_fusion: rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb.get_cached_cos_sin( - rotary_seq_len, dtype=hidden_states.dtype, packed_seq_params=packed_seq_params + rotary_seq_len, dtype=hidden_states.dtype, packed_seq=packed_seq ) rotary_pos_emb = None assert inference_context is None, "Inference with MLA RoPE fusion is not supported" @@ -556,9 +554,7 @@ def get_query_key_value_tensors( and fused_apply_mla_rope_for_kv is not None ), "Fused MLA RoPE apply is not imported successfully" else: - rotary_pos_emb, mscale = self.rotary_pos_emb( - rotary_seq_len, packed_seq_params=packed_seq_params - ) + rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': if packed_seq_params.cu_seqlens_q_padded is not None: From 1860dbeb5d77bb05ad6926f4a897c0b6be6c2a06 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 15 Jan 2026 17:11:03 -0800 Subject: [PATCH 79/81] Updated Golden_config for Mamba Moe --- tests/unit_tests/models/test_mamba_moe_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py index c4cfbc2e56c..7e9987cf2ab 100644 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -257,6 +257,8 @@ "fine_grained_activation_offloading": False, "min_offloaded_tensor_size": 1024 * 1024, "offload_modules": [], + "hybrid_context_parallel": False, + "max_seqlen_per_dp_cp_rank": None, } # Fields to ignore entirely (ephemeral, environment-specific, very large). SKIP_FIELDS = set() From 06d086fe684b2ef4243ed04bb5f3da744666e9f4 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 15 Jan 2026 22:51:32 -0800 Subject: [PATCH 80/81] Fix for pretrain_mamba --- pretrain_mamba.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pretrain_mamba.py b/pretrain_mamba.py index 3b2ec5b47dc..2966f76bea4 100644 --- a/pretrain_mamba.py +++ b/pretrain_mamba.py @@ -78,7 +78,6 @@ def get_batch(data_iterator, vp_stage=None): # Support for Hybrid Context Parallel (Unused in this script) local_cp_size = batch.pop('local_cp_size', None) - cu_seqlens = batch['cu_seqlens'] if cu_seqlens is not None: assert ( cu_seqlens.dim() == 2 and cu_seqlens.shape[0] == 1 @@ -86,7 +85,6 @@ def get_batch(data_iterator, vp_stage=None): cu_seqlens = cu_seqlens[0] batch['cu_seqlens'] = cu_seqlens - max_seqlen = batch['max_seqlen'] assert max_seqlen.dim() == 1 # TODO(duncan): can this be kept as a 0-D tensor? batch['max_seqlen'] = int(max_seqlen[0].item()) From aa99bcfbd093d695bae87f013a9be59bc16eab2c Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Fri, 16 Jan 2026 10:15:22 -0800 Subject: [PATCH 81/81] Fix error --- pretrain_mamba.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pretrain_mamba.py b/pretrain_mamba.py index 2966f76bea4..56f52631e79 100644 --- a/pretrain_mamba.py +++ b/pretrain_mamba.py @@ -71,10 +71,9 @@ def get_batch(data_iterator, vp_stage=None): batch = get_batch_on_this_tp_rank(data_iterator) - # Support for Packed Sequence (Unused in this script) - cu_seqlens = batch.pop('cu_seqlens', None) + cu_seqlens = batch['cu_seqlens'] + # Unused at the moment cu_seqlens_padded = batch.pop('cu_seqlens_padded', None) - max_seqlen = batch.pop('max_seqlen', None) # Support for Hybrid Context Parallel (Unused in this script) local_cp_size = batch.pop('local_cp_size', None) @@ -85,6 +84,7 @@ def get_batch(data_iterator, vp_stage=None): cu_seqlens = cu_seqlens[0] batch['cu_seqlens'] = cu_seqlens + max_seqlen = batch['max_seqlen'] assert max_seqlen.dim() == 1 # TODO(duncan): can this be kept as a 0-D tensor? batch['max_seqlen'] = int(max_seqlen[0].item())