diff --git a/examples/hybrid_cp/run_hybrid_cp.sh b/examples/hybrid_cp/run_hybrid_cp.sh new file mode 100755 index 00000000000..7f9d289e2e7 --- /dev/null +++ b/examples/hybrid_cp/run_hybrid_cp.sh @@ -0,0 +1,75 @@ +#!/bin/bash + +export NCCL_IB_SL=1 +export TOKENIZERS_PARALLELISM="false" +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +MCORE_PATH="../" +OUTPUT_BASE="./output" +SEQ_LEN=16384 + +DYNAMIC_CP_ARGS=" \ + --dynamic-context-parallel \ + --sequence-packing \ + --calculate-per-token-loss \ + --max-seqlen-per-dp-cp-rank 4096 \ +" + +ARGS=" \ + --sft \ + --legacy-tokenizer \ + --tokenizer-type NullTokenizer \ + --vocab-size 131072 \ + --mock-data \ + --sft-mock-dataset-config-json {\"mode\":\"distribution\",\"type\":\"lognormal\",\"min_seq_len\":1024,\"max_seq_len\":16384,\"mean_seq_len\":8192,\"lognormal_sigma\":1.1} \ + --use-distributed-optimizer \ + --disable-bias-linear \ + --transformer-impl transformer_engine \ + --normalization RMSNorm \ + --norm-epsilon 1e-06 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --untie-embeddings-and-output-weights \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --rerun-mode disabled \ + --num-layers 4 \ + --hidden-size 2048 \ + --ffn-hidden-size 8192 \ + --add-qkv-bias \ + --num-attention-heads 16 \ + --num-workers 8 \ + --exit-duration-in-mins 230 \ + --seq-length ${SEQ_LEN} \ + --max-position-embeddings ${SEQ_LEN} \ + --train-samples 100000 \ + --lr-warmup-samples 20000 \ + --micro-batch-size 4 \ + --global-batch-size 256 \ + --lr 2e-5 \ + --min-lr 0.0 \ + --lr-decay-style cosine \ + --log-interval 1 \ + --eval-iters 10 \ + --eval-interval 999999 \ + --save-interval 1000 \ + --use-mcore-models \ + --no-create-attention-mask-in-dataloader \ + --no-mmap-bin-files \ + --clip-grad 1.0 \ + --weight-decay 0.05 \ + --adam-beta1 0.9 \ + --adam-beta2 0.999 \ + --init-method-std 0.014 \ 
+ --bf16 \ + --distributed-timeout-minutes 60 \ + --attention-backend flash \ + --disable-gloo-process-groups \ + --use-dist-ckpt \ +" + +torchrun --nproc_per_node 8 ${MCORE_PATH}/pretrain_gpt.py ${ARGS} ${DYNAMIC_CP_ARGS} diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index 0f016473b6a..1184e9cc4c4 100644 --- a/megatron/core/datasets/data_schedule.py +++ b/megatron/core/datasets/data_schedule.py @@ -1,23 +1,51 @@ # Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. -from typing import Any, List, Optional +import enum +from collections import deque +from functools import lru_cache +from math import ceil, log2 +from typing import Callable, Deque, Dict, List, Optional, Tuple, Type, Union +import numpy as np import torch from megatron.core import parallel_state -from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler +from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.rerun_state_machine import RerunDataIterator +from megatron.core.utils import is_te_min_version +try: + # Register the TE CUDA kernels + import transformer_engine # pylint: disable=unused-import -class HybridCPDataLoaderWrapper: + # Alias the PyTorch wrapper so we can call tex.* APIs + import transformer_engine_torch as tex +except ImportError: + # TE isn’t installed or the torch wrapper is missing + tex = None + + +class PackingScheduler(enum.Enum): + """Enum for supported sequence packing algorithms.""" + + # custom dynamic-cp scheduler, schedule in samplers, only need to pack + EMPTY_PACKING = "empty_scheduler_with_packing" + # custom dynamic-cp scheduler, schedule in samplers and pack in collate_fn + EMPTY_NO_PACKING = "empty_scheduler_no_packing" + NAIVE_SEQUENCE_PACKING = "naive_sequence_packing" + DEFAULT_DYNAMIC_CP = "default_dynamic_cp" + + +def wrap_dataloader( + data_iterator, + config, + scheduler_type: 
Union[PackingScheduler, str], + pg_collection: Optional[ProcessGroupCollection] = None, +): """ - A wrapper class that wraps around an existing data_iterator. - For every __next__ call, - 1. Each DP rank pulls a batch of packed samples. - 2. Extracts the sequence lengths of each sub-sample and all-gathers across the DP group. - 3. Schedules the sub-samples to the DPxCP ranks using the BalancedCPScheduler. - 4. Based on the schedule, reroutes the sub-samples to the correct rank using all-to-all. - 5. Returns the assigned sub-samples to this rank. + A wrapper function that wraps around an existing data_iterator + and return the num_micro_batches for sequence packing. Args: data_iterator: The original data_iterator to wrap around @@ -25,34 +53,14 @@ class HybridCPDataLoaderWrapper: dp_cp_group: Data parallel context parallel group. """ - def __init__( - self, data_iterator, config, pg_collection: Optional[ProcessGroupCollection] = None - ): - self.data_iterator = data_iterator - self.config = config - if pg_collection is None: - self.dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) - self.dp_group = parallel_state.get_data_parallel_group() - self.tp_group = parallel_state.get_tensor_model_parallel_group() - else: - self.dp_cp_group = pg_collection.dp_cp - self.dp_group = pg_collection.dp - self.tp_group = pg_collection.tp - assert ( - self.dp_cp_group is not None and self.dp_group is not None and self.tp_group is not None - ), "dp_cp_group, dp_group, tp_group must not be None when using hybrid context parallel" - - self.cp_balancing_scheduler = BalancedCPScheduler( - max_seq_len_per_rank=self.config.max_seqlen_per_dp_cp_rank, dp_cp_group=self.dp_cp_group - ) - - self.total_hdp_gpus = self.dp_cp_group.size() - - def __iter__(self): - """Return self as an iterator.""" - return self + scheduler_map: Dict[PackingScheduler, Type[BaseScheduler]] = { + PackingScheduler.DEFAULT_DYNAMIC_CP: DefaultDynamicCPscheduler, + 
PackingScheduler.NAIVE_SEQUENCE_PACKING: NaiveSequencePackingScheduler, + PackingScheduler.EMPTY_PACKING: EmptyPackingScheduler, + PackingScheduler.EMPTY_NO_PACKING: EmptyNoPackingScheduler, + } - def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: + def _get_global_seqlens(subsample_seqlens: torch.Tensor, dp_group) -> List[int]: """ Gathers the sequence lengths of all subsamples from all DP ranks. Each DP rank loads the same number of microbatches but each microbatch @@ -63,8 +71,8 @@ def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: """ # Collect the number of subsamples from all ranks local_len = torch.tensor([subsample_seqlens.shape[0]], dtype=torch.int32).cuda() - dp_subsample_count = [torch.zeros_like(local_len) for _ in range(self.dp_group.size())] - torch.distributed.all_gather(dp_subsample_count, local_len, group=self.dp_group) + dp_subsample_count = [torch.zeros_like(local_len) for _ in range(dp_group.size())] + torch.distributed.all_gather(dp_subsample_count, local_len, group=dp_group) # Find the max number of subsamples across all ranks and pad subsample_seqlens to max length dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1) @@ -83,11 +91,9 @@ def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: # Gather the subsample_seqlens from all ranks seqlens_gathered = [ - torch.empty_like(subsample_seqlens_padded) for _ in range(self.dp_group.size()) + torch.empty_like(subsample_seqlens_padded) for _ in range(dp_group.size()) ] - torch.distributed.all_gather( - seqlens_gathered, subsample_seqlens_padded, group=self.dp_group - ) + torch.distributed.all_gather(seqlens_gathered, subsample_seqlens_padded, group=dp_group) # Trim each seqlens_gathered to the length of the correct sample for dp_rank, seqlen in enumerate(seqlens_gathered): @@ -102,7 +108,7 @@ def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: return seqlens_gathered, offsets - def 
get_global_id_seqlens(self, num_local_subsamples, offsets, seqlens_gathered): + def _get_global_id_seqlens(num_local_subsamples, offsets, seqlens_gathered, dp_group): """ Calculates the global ID for each subsample. @@ -112,7 +118,7 @@ def get_global_id_seqlens(self, num_local_subsamples, offsets, seqlens_gathered) global_id_seqlens: list of (global_id, seqlen) tuples for scheduling. global_ids_this_rank: list of global IDs locally present on this rank. """ - dp_rank = self.dp_group.rank() + dp_rank = dp_group.rank() global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda() # Create a list of (global_id, seqlen) tuples for scheduling global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] @@ -123,18 +129,25 @@ def get_global_id_seqlens(self, num_local_subsamples, offsets, seqlens_gathered) return global_id_seqlens, global_ids_this_rank - def _gid_to_src_rank(self, gid: int, offsets: List[int]) -> int: + def _gid_to_src_rank(gid: int, offsets: List[int], dp_group, tp_group, dp_cp_group) -> int: dp_src_rank = torch.bucketize(gid, offsets[1:] - 1) # Since the torch.distributed.get_process_group_ranks # provides the global rank, we need to consider TP hdp_rank = ( - torch.distributed.get_process_group_ranks(self.dp_group)[dp_src_rank] - // self.tp_group.size() - ) + torch.distributed.get_process_group_ranks(dp_group)[dp_src_rank] // tp_group.size() + ) % dp_cp_group.size() return hdp_rank - def reroute_samples_to_hdp_ranks( - self, batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets + def _reroute_samples_to_hdp_ranks( + batch, + global_ids_this_rank, + global_id_seqlens, + sample_id_groups, + offsets, + dp_group, + tp_group, + dp_cp_group, + total_hdp_gpus, ): """ Reroutes the sub-samples to the correct rank after scheduling. @@ -145,22 +158,23 @@ def reroute_samples_to_hdp_ranks( to transfer data between matching CP ranks. 
""" gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} - hdp_rank = self.dp_cp_group.rank() - dp_ranks = torch.distributed.get_process_group_ranks(self.dp_group) + hdp_rank = dp_cp_group.rank() + dp_ranks = torch.distributed.get_process_group_ranks(dp_group) # Here we actually want to get the DP group's rank within the HDP group, # we need to consider TP - dp_ranks = [r // self.tp_group.size() for r in dp_ranks] + # tp-cp-ep-dp-pp + dp_ranks = [(r // tp_group.size()) % dp_cp_group.size() for r in dp_ranks] data_keys = batch[0].keys() # Create the send plan - combined_sample_id_groups: List[List[int]] = [[] for _ in range(self.total_hdp_gpus)] + combined_sample_id_groups: List[List[int]] = [[] for _ in range(total_hdp_gpus)] - for d in range(self.total_hdp_gpus): + for d in range(total_hdp_gpus): for sample_id_group in sample_id_groups: combined_sample_id_groups[d].extend(sample_id_group[d]) - for dest_rank in range(self.total_hdp_gpus): + for dest_rank in range(total_hdp_gpus): combined_sample_id_groups[dest_rank].sort() # Filter out samples that are not present on this rank @@ -170,38 +184,37 @@ def reroute_samples_to_hdp_ranks( for gid in combined_sample_id_groups[d] if gid in global_ids_this_rank ] - # send_counts = [len(combined_sample_id_groups[d]) for d in range(self.total_hdp_gpus)] + # send_counts = [len(combined_sample_id_groups[d]) for d in range(total_hdp_gpus)] - send_lens_split = [0] * self.total_hdp_gpus - for dest_rank in range(self.total_hdp_gpus): + send_num_split = [0] * total_hdp_gpus + send_lens_split = [0] * total_hdp_gpus + for dest_rank in range(total_hdp_gpus): if dest_rank in dp_ranks: - send_lens_split[dest_rank] = sum( - [ - global_id_seqlens[gid][1] - for gid in combined_sample_id_groups[dest_rank] - if gid in global_ids_this_rank - ] - ) + send_seq_lens = [ + global_id_seqlens[gid][1] + for gid in combined_sample_id_groups[dest_rank] + if gid in global_ids_this_rank + ] + send_num_split[dest_rank] = 
len(send_seq_lens) + send_lens_split[dest_rank] = sum(send_seq_lens) else: # We only need to share local data with DP ranks that have different data. send_lens_split[dest_rank] = 0 # Create the recv plan - recv_sample_id_groups = [[] for _ in range(self.total_hdp_gpus)] + recv_sample_id_groups = [[] for _ in range(total_hdp_gpus)] for gid in combined_sample_id_groups[hdp_rank]: - src_rank = self._gid_to_src_rank(gid, offsets) + src_rank = _gid_to_src_rank(gid, offsets, dp_group, tp_group, dp_cp_group) recv_sample_id_groups[src_rank].append(gid) - recv_lens_split = [0] * self.total_hdp_gpus - for src_rank in range(self.total_hdp_gpus): + recv_lens_split = [0] * total_hdp_gpus + for src_rank in range(total_hdp_gpus): recv_lens_split[src_rank] = sum( [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]] ) - recv_ids_sorted = [ - gid for d in range(self.total_hdp_gpus) for gid in recv_sample_id_groups[d] - ] - recv_counts = [len(recv_sample_id_groups[d]) for d in range(self.total_hdp_gpus)] + recv_ids_sorted = [gid for d in range(total_hdp_gpus) for gid in recv_sample_id_groups[d]] + recv_counts = [len(recv_sample_id_groups[d]) for d in range(total_hdp_gpus)] recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))] @@ -209,31 +222,42 @@ def _pack_sample_by_key(key: str) -> torch.Tensor: flattened_tensors = [] for gid in send_ids_sorted: t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True) - flattened_tensors.append(t) + # flattened_tensors.append(t) + flattened_tensors.append(t.reshape(-1)) return ( torch.cat(flattened_tensors, dim=0) if flattened_tensors - else torch.empty(0, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) + else torch.empty(1, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) ) def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): cursor = 0 for i, gid in enumerate(recv_ids_sorted): - sample_len = global_id_seqlens[gid][1] + sample_len = ( + 1 
+ if key in ["original_seq_len", "padded_seq_len"] + else global_id_seqlens[gid][1] + ) recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len] cursor += sample_len for key in data_keys: + output_split_sizes, input_split_sizes = ( + (recv_counts, send_num_split) + if key in ["original_seq_len", "padded_seq_len"] + else (recv_lens_split, send_lens_split) + ) send_tensor = _pack_sample_by_key(key) + recv_tensor_size = sum(output_split_sizes) recv_tensor = torch.empty( - sum(recv_lens_split), device=torch.cuda.current_device(), dtype=send_tensor.dtype + recv_tensor_size, device=torch.cuda.current_device(), dtype=send_tensor.dtype ) torch.distributed.all_to_all_single( output=recv_tensor, input=send_tensor, - output_split_sizes=recv_lens_split, - input_split_sizes=send_lens_split, - group=self.dp_cp_group, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=dp_cp_group, ) _unpack_sample_by_key(key, recv_tensor) @@ -242,7 +266,7 @@ def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): } return recv_sample_with_id - def unpack_batch(self, batch): + def _unpack_batch(batch): """ Unpacks the packed samples into a list of sub-samples. 
Since each sub-sample may be routed to different DPxCP ranks, @@ -251,51 +275,1626 @@ def unpack_batch(self, batch): """ batch_unpacked = [] for sample in batch: - for sub_sample in range(sample["cu_seqlens"].shape[0] - 1): - sub_sample_dict = {} - start_idx = sample["cu_seqlens"][sub_sample] - end_idx = sample["cu_seqlens"][sub_sample + 1] - if end_idx - start_idx == 0: + sample_dict = {} + for key in sample.keys(): + if key in ["cu_seqlens", "batch_idx", "max_seqlen"]: continue - for key in sample.keys(): - if key in ["cu_seqlens", "batch_idx", "max_seqlen"]: - continue - sub_sample_dict[key] = sample[key][start_idx:end_idx] - batch_unpacked.append(sub_sample_dict) + sample_dict[key] = sample[key] + batch_unpacked.append(sample_dict) return batch_unpacked - def __next__(self) -> Any: + def _broadcast_to_tp_group(item): + if item is not None: + torch.distributed.broadcast( + item, + parallel_state.get_tensor_model_parallel_src_rank(), + group=parallel_state.get_tensor_model_parallel_group(), + ) + + def _broadcast_to_pp_group(item): + if item is not None: + torch.distributed.broadcast( + item, + parallel_state.get_pipeline_model_parallel_first_rank(), + group=parallel_state.get_pipeline_model_parallel_group(), + ) + + def _pack_sequences( + samples: List, + local_cp_size: torch.Tensor, + padded_lengths: torch.Tensor, + original_lengths: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + def _pack_tensors(tensors): + return torch.cat([t.reshape(-1) for t in tensors], dim=0) + + tokens = _pack_tensors([sample["tokens"] for sample in samples]) + labels = _pack_tensors([sample["labels"] for sample in samples]) + loss_mask = _pack_tensors([sample["loss_mask"] for sample in samples]) + position_ids = _pack_tensors([sample["position_ids"] for sample in samples]) + + new_sample = {} + new_sample["tokens"] = tokens + new_sample["labels"] = labels + new_sample["loss_mask"] = loss_mask + new_sample["position_ids"] = position_ids + if local_cp_size is not None: + # Accept either 
a Python int or a CUDA scalar tensor; keep as a scalar tensor on GPU. + new_sample["local_cp_size"] = local_cp_size.to(device=dev, dtype=torch.int32) + + padded_lengths = padded_lengths.to( + device=dev, dtype=torch.int32, non_blocking=True + ).reshape(-1) + cu_seqlens_padded = torch.empty(padded_lengths.numel() + 1, device=dev, dtype=torch.int32) + cu_seqlens_padded[0] = 0 + cu_seqlens_padded[1:] = torch.cumsum(padded_lengths, dim=0) + max_seqlen = torch.max(padded_lengths).to(dtype=torch.int32) + + new_sample["cu_seqlens_padded"] = cu_seqlens_padded + new_sample["max_seqlen"] = max_seqlen + + # create cu_seqlens without padding + + original_lengths = original_lengths.to( + device=dev, dtype=torch.int32, non_blocking=True + ).reshape(-1) + cu_seqlens = torch.empty(original_lengths.numel() + 1, device=dev, dtype=torch.int32) + cu_seqlens[0] = 0 + cu_seqlens[1:] = torch.cumsum(original_lengths, dim=0).reshape(-1) + new_sample["cu_seqlens"] = cu_seqlens + + return new_sample + + def _build_packed_microbatches( + grouped_samples: List[List[Dict[str, torch.Tensor]]], + scheduler_type: PackingScheduler, + local_cp_sizes_gpu: Optional[torch.Tensor], + ) -> List[Dict[str, torch.Tensor]]: """ - Get the next item from the dataset, pull scheduling metadata and return it. + Build packed samples for each microbatch given a pre-built list of `samples` per microbatch. + + Args: + grouped_samples: List of length `num_microbatches`. Each element is the `samples` list + (list[sample]) for that microbatch, where `sample` is the dict returned by + `dataset.__getitem__`. + scheduler_type: packing scheduler. + local_cp_sizes_gpu: CUDA int32 tensor of shape [num_micro_batches] + when DEFAULT_DYNAMIC_CP, otherwise None. + + Returns: + new_samples: list of packed samples (dicts) length == num_micro_batches. """ - if self.data_iterator is None: - # TP0 reads from data_iterator, others receive via broadcast. 
- return None, None + num_micro_batches = len(grouped_samples) + seg_starts: List[int] = [0] + original_lens_tensors = [] + padded_lens_tensors = [] + + for i in range(num_micro_batches): + samples = grouped_samples[i] + seg_starts.append(seg_starts[-1] + len(samples)) + original_lens_tensors.extend([s["original_seq_len"].reshape(-1) for s in samples]) + padded_lens_tensors.extend([s["padded_seq_len"].reshape(-1) for s in samples]) + + padded_lens_all_gpu = torch.cat(padded_lens_tensors, dim=0).to(dtype=torch.int32) + original_lens_all_gpu = torch.cat(original_lens_tensors, dim=0).to(dtype=torch.int32) + + new_samples: List[Dict[str, torch.Tensor]] = [] + for i in range(num_micro_batches): + samples = grouped_samples[i] + + lp = padded_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]] + lo = original_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]] + + if scheduler_type is PackingScheduler.DEFAULT_DYNAMIC_CP: + assert local_cp_sizes_gpu is not None + partner_cp_arg = local_cp_sizes_gpu[i] + else: + partner_cp_arg = None + + new_sample = _pack_sequences(samples, partner_cp_arg, lp, lo) + new_samples.append(new_sample) + + return new_samples + + # Convert string to enum if needed + if isinstance(scheduler_type, str): + try: + scheduler_type = PackingScheduler[scheduler_type.upper()] + except KeyError: + available_scheduler = ", ".join([scheduler.name for scheduler in PackingScheduler]) + raise ValueError( + f"Unknown packing scheduler: {scheduler_type}. " + f"Available schedulers: {available_scheduler}" + ) + + if scheduler_type not in scheduler_map: + available_scheduler = ", ".join([scheduler.name for scheduler in PackingScheduler]) + raise ValueError( + f"Unknown scheduler: {scheduler}. 
" f"Available schedulers: {available_scheduler}" + ) + + if pg_collection is None: + dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) + dp_group = parallel_state.get_data_parallel_group() + tp_group = parallel_state.get_tensor_model_parallel_group() + pp_group = parallel_state.get_pipeline_model_parallel_group() + else: + dp_cp_group = pg_collection.dp_cp + dp_group = pg_collection.dp + tp_group = pg_collection.tp + pp_group = pg_collection.pp + assert ( + dp_cp_group is not None and dp_group is not None and tp_group is not None + ), "dp_cp_group, dp_group, tp_group must not be None when using dynamic context parallel" + + total_hdp_gpus = dp_cp_group.size() + dev = torch.cuda.current_device() + + scheduler = scheduler_map[scheduler_type]( + config.max_seqlen_per_dp_cp_rank, + parallel_state.get_context_parallel_world_size(), + parallel_state.get_data_parallel_world_size(), + # When VPP is enabled, align num_micro_batches to this multiple. + ( + None + if (config.virtual_pipeline_model_parallel_size is None or + config.virtual_pipeline_model_parallel_size == 1) + else config.microbatch_group_size_per_vp_stage + ), + config.dynamic_context_parallel, + ) + if ( + config.virtual_pipeline_model_parallel_size is not None + and config.virtual_pipeline_model_parallel_size > 1 + ): + if pp_group.rank() == pp_group.size() - 1: + assert len(data_iterator) == config.virtual_pipeline_model_parallel_size + data_iterator = data_iterator[-1] else: - batch = next(self.data_iterator) - subsample_seqlens = [] - for sample in batch: - subsample_seqlens.extend( - [ - int(sample["cu_seqlens"][i + 1] - sample["cu_seqlens"][i]) - for i in range(0, sample["cu_seqlens"].shape[0] - 1) + data_iterator = data_iterator[0] + + if data_iterator is not None: + batch = next(data_iterator) + + # check if the batch contains all the required keys + assert scheduler.check_require_sample_keys(batch), "Batch missing required keys" + # indicates TP rank 0, with PP stage 0 
or -1. + if scheduler_type is PackingScheduler.EMPTY_PACKING: + # EMPTY scheduler does not schedule the data, + # just packing sequences + + # Here, next(data_iterator) returns multiple samples, i.e., a list[sample]. + # Each sample is the value returned by dataset.__getitem__ (see + # `megatron/training/datasets/sft_dataset.py` for the concrete sample fields). + # This indicates that, according to the scheduler's result, + # these samples (sequences) should be packed together. + num_micro_batches = batch[0]["num_micro_batches_left"] + 1 + + batch_all = [batch] + [next(data_iterator) for _ in range(num_micro_batches - 1)] + + num_total_tokens = 0 + sequence_square_sum = 0 + + # pack sequences in the same group and create a new data iterator + new_samples = [] + for samples in batch_all: + # In EMPTY scheduler, scheduler has already selected the grouping and + # provides `local_cp_size` for each packed group. + local_cp_size = samples[0]["local_cp_size"] + # Convert local_cp_size to a python int for FLOPs accounting + local_cp_size_int = ( + int(local_cp_size.item()) + if isinstance(local_cp_size, torch.Tensor) + else int(local_cp_size) + ) + + # Build per-sub-sample length tensors on GPU so `_pack_sequences` can compute + # cu_seqlens and cu_seqlens_padded via GPU cumsum. + padded_lengths = torch.cat( + [s["padded_seq_len"].reshape(-1) for s in samples], dim=0 + ) + original_lengths = torch.cat( + [s["original_seq_len"].reshape(-1) for s in samples], dim=0 + ) + + new_sample = _pack_sequences( + samples, local_cp_size, padded_lengths, original_lengths + ) + new_samples.append(new_sample) + for sample in samples: + n = sample["tokens"].numel() + # tokens.numel() is a python int (no D2H). local_cp_size_int is python int. 
+ num_total_tokens += n / local_cp_size_int + sequence_square_sum += n**2 / local_cp_size_int + # allreduce to get the total number of microbatches + flops_info_to_broadcast_this_hdp_group = torch.tensor( + [num_total_tokens, sequence_square_sum], dtype=torch.float32, device=dev + ) + torch.distributed.all_reduce(flops_info_to_broadcast_this_hdp_group, group=dp_cp_group) + flops_info_cpu = flops_info_to_broadcast_this_hdp_group.cpu().numpy() + num_total_tokens_this_global_batch = flops_info_cpu[0] + sequence_square_sum_this_global_batch = flops_info_cpu[1] + + elif scheduler_type is PackingScheduler.EMPTY_NO_PACKING: + # EMPTY_NO_PACKING scheduler does not schedule the data, + # sequences are already packed, we just need to return the batch + num_micro_batches = batch["num_micro_batches_left"] + 1 + num_total_tokens_this_global_batch = batch["num_total_tokens_this_global_batch"] + sequence_square_sum_this_global_batch = batch["sequence_square_sum_this_global_batch"] + new_samples = [batch] + [next(data_iterator) for _ in range(num_micro_batches - 1)] + + elif ( + scheduler_type is PackingScheduler.DEFAULT_DYNAMIC_CP + or scheduler_type is PackingScheduler.NAIVE_SEQUENCE_PACKING + ): + + subsample_seqlens = [] + for sample in batch: + subsample_seqlens.extend([sample["tokens"].numel()]) + subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda() + subsample_seqlens = subsample_seqlens[subsample_seqlens != 0] + + seqlens_gathered, offsets = _get_global_seqlens(subsample_seqlens, dp_group) + + global_id_seqlens, global_ids_this_rank = _get_global_id_seqlens( + subsample_seqlens.shape[0], offsets, seqlens_gathered, dp_group + ) + + sample_id_groups = scheduler.get_groups_and_subsamples(global_id_seqlens) + + set_gbs = set() + for group in sample_id_groups: + for sub in group: + set_gbs.update(sub) + assert len(set_gbs) == len( + global_id_seqlens + ), f"set_gbs length: {len(set_gbs)} \ + != global_ids_this_rank length: {len(global_id_seqlens)}" + + 
batch = _unpack_batch(batch) + samples_this_rank_with_id = _reroute_samples_to_hdp_ranks( + batch, + global_ids_this_rank, + global_id_seqlens, + sample_id_groups, + offsets, + dp_group, + tp_group, + dp_cp_group, + total_hdp_gpus, + ) + batch, sample_id_groups = samples_this_rank_with_id, sample_id_groups + + hdp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) + num_micro_batches = len(sample_id_groups) + + # local_cp_sizes_gpu is computed outside and passed in for DEFAULT_DYNAMIC_CP. + if scheduler_type is PackingScheduler.DEFAULT_DYNAMIC_CP: + # One H2D total + local_cp_sizes_cpu: List[int] = [] + for i in range(num_micro_batches): + sample_ids_this_group = sample_id_groups[i][hdp_rank] + local_cp_sizes_cpu.append( + len( + [ + 1 + for sample_ids in sample_id_groups[i] + if sample_ids_this_group[0] in sample_ids + ] + ) + ) + local_cp_sizes_gpu = torch.tensor(local_cp_sizes_cpu, dtype=torch.int32, device=dev) + else: + local_cp_sizes_gpu = None + + grouped_samples = [ + [batch[sub_sample_id] for sub_sample_id in sample_id_groups[i][hdp_rank]] + for i in range(num_micro_batches) + ] + + new_samples = _build_packed_microbatches( + grouped_samples=grouped_samples, + scheduler_type=scheduler_type, + local_cp_sizes_gpu=local_cp_sizes_gpu, + ) + + # calculate this two values for tflops calculation + num_total_tokens_this_global_batch = float(sum(seqlens_gathered)) + sequence_square_sum_this_global_batch = float( + sum(seqlen**2 for seqlen in seqlens_gathered) + ) + + # broadcast num_micro_batches, num_total_tokens_this_global_batch, + # sequence_square_sum_this_global_batch, and packed_seq_params to PP group + if pp_group.size() > 2 and tp_group.rank() == 0: + if pp_group.rank() == 0: + tensor_list = [ + torch.tensor( + [ + num_micro_batches, + num_total_tokens_this_global_batch, + sequence_square_sum_this_global_batch, + ], + dtype=torch.float32, + ).cuda() + ] + for sample in new_samples: + 
tensor_list.append(sample["max_seqlen"].unsqueeze(0)) + for sample in new_samples: + tensor_list.append( + sample["local_cp_size"].unsqueeze(0) + if scheduler_type is PackingScheduler.DEFAULT_DYNAMIC_CP + else torch.tensor([-1], dtype=torch.float32).cuda() + ) + for sample in new_samples: + tensor_list.append(sample["cu_seqlens"]) + tensor_list.append(sample["cu_seqlens_padded"]) + info_to_broadcast_this_pp_group = torch.cat(tensor_list, dim=0).to( + device=dev, dtype=torch.float32 + ) + info_length_tensor = torch.tensor( + info_to_broadcast_this_pp_group.shape[0], dtype=torch.int32 + ).cuda() + _broadcast_to_pp_group(info_length_tensor) + _broadcast_to_pp_group(info_to_broadcast_this_pp_group) + else: + info_length_tensor = torch.tensor(0, dtype=torch.int32).cuda() + _broadcast_to_pp_group(info_length_tensor) + info_to_broadcast_this_pp_group = torch.empty( + info_length_tensor.item(), dtype=torch.float32 + ).cuda() + _broadcast_to_pp_group(info_to_broadcast_this_pp_group) + if pp_group.rank() != pp_group.size() - 1: + info_numpy = info_to_broadcast_this_pp_group.cpu().numpy() + num_micro_batches = int(info_numpy[0]) + num_total_tokens_this_global_batch = info_numpy[1] + sequence_square_sum_this_global_batch = info_numpy[2] + max_seqlens = info_to_broadcast_this_pp_group[3 : 3 + num_micro_batches] + is_dynamic_cp = int(info_numpy[3 + num_micro_batches]) != -1 + local_cp_sizes = info_to_broadcast_this_pp_group[ + 3 + num_micro_batches : 3 + 2 * num_micro_batches ] + cu_seqlens_list = [] + cu_seqlens_padded_list = [] + indices = np.where(info_numpy == 0)[0] + for i in range(num_micro_batches): + cu_seqlens_list.append( + info_to_broadcast_this_pp_group[indices[i * 2] : indices[i * 2 + 1]] + ) + if i == num_micro_batches - 1: + cu_seqlens_padded_list.append( + info_to_broadcast_this_pp_group[indices[i * 2 + 1] :] + ) + else: + cu_seqlens_padded_list.append( + info_to_broadcast_this_pp_group[indices[i * 2 + 1] : indices[i * 2 + 2]] + ) + + new_samples = [] + for i in 
range(num_micro_batches): + new_sample = {} + new_sample["max_seqlen"] = max_seqlens[i].to(torch.int32) + if is_dynamic_cp: + new_sample["local_cp_size"] = local_cp_sizes[i].to(torch.int32) + new_sample["cu_seqlens"] = cu_seqlens_list[i].to(torch.int32) + new_sample["cu_seqlens_padded"] = cu_seqlens_padded_list[i].to(torch.int32) + new_samples.append(new_sample) + + if tp_group.size() > 1: + # TODO(tailaim): This triggers H2D/D2H transfers. + # Ideally, we should perform the communication over a CPU process + if tp_group.rank() == 0: + info_to_broadcast_this_tpgroup = torch.tensor( + [ + num_micro_batches, + num_total_tokens_this_global_batch, + sequence_square_sum_this_global_batch, + ], + dtype=torch.float32, + device=dev, ) - subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda() - subsample_seqlens = subsample_seqlens[subsample_seqlens != 0] + _broadcast_to_tp_group(info_to_broadcast_this_tpgroup) + else: + info_to_broadcast_this_tpgroup = torch.tensor( + [0, 0, 0], dtype=torch.float32, device=dev + ) + _broadcast_to_tp_group(info_to_broadcast_this_tpgroup) + info_numpy = info_to_broadcast_this_tpgroup.cpu().numpy() + num_micro_batches = int(info_numpy[0]) + num_total_tokens_this_global_batch = info_numpy[1] + sequence_square_sum_this_global_batch = info_numpy[2] + + if ( + config.virtual_pipeline_model_parallel_size is not None + and config.virtual_pipeline_model_parallel_size > 1 + ): + vpp_size = config.virtual_pipeline_model_parallel_size + if tp_group.rank() == 0: + if pp_group.rank() == 0 or pp_group.rank() == pp_group.size() - 1: + new_samples_for_other_ppstage = [] + for sample in new_samples: + new_sample_for_other_ppstage = {} + new_sample_for_other_ppstage["max_seqlen"] = sample["max_seqlen"] + new_sample_for_other_ppstage["cu_seqlens"] = sample["cu_seqlens"] + new_sample_for_other_ppstage["cu_seqlens_padded"] = sample["cu_seqlens_padded"] + if scheduler_type is PackingScheduler.DEFAULT_DYNAMIC_CP: + 
new_sample_for_other_ppstage["local_cp_size"] = sample["local_cp_size"] + new_samples_for_other_ppstage.append(new_sample_for_other_ppstage) + if pp_group.rank() == 0: + new_data_iterator = [RerunDataIterator(iter(new_samples))] + [ + RerunDataIterator(iter(new_samples_for_other_ppstage)) + for _ in range(vpp_size - 1) + ] + else: + new_data_iterator = [ + RerunDataIterator(iter(new_samples_for_other_ppstage)) + for _ in range(vpp_size - 1) + ] + [RerunDataIterator(iter(new_samples))] + else: + new_data_iterator = [RerunDataIterator(iter(new_samples)) for _ in range(vpp_size)] + else: + new_data_iterator = [None for _ in range(vpp_size)] + else: + new_data_iterator = RerunDataIterator(iter(new_samples)) if tp_group.rank() == 0 else None + + return ( + new_data_iterator, + num_micro_batches, + num_total_tokens_this_global_batch, + sequence_square_sum_this_global_batch, + ) + + +class BaseScheduler: + """ + Base class for sequence packing schedulers. + """ + + def __init__( + self, + max_seqlen_per_dp_cp_rank: Optional[int], + cp_size: int, + dp_size: int, + microbatch_group_size_per_vp_stage: Optional[int], + dynamic_context_parallel: bool = False, + ): + self.max_seqlen_per_dp_cp_rank = max_seqlen_per_dp_cp_rank + self.cp_size = cp_size + self.dp_size = dp_size + self.microbatch_group_size_per_vp_stage = microbatch_group_size_per_vp_stage + self.dynamic_context_parallel = dynamic_context_parallel + + def check_require_sample_keys(self, batch: List[Dict]): + """ + Required per-(sub)sample fields expected by the scheduler. + """ + raise NotImplementedError + + def get_groups_and_subsamples(self, sample_id_seqlens): + """ + This scheduler simply packs sequences in their original order + until reaching the max sequence length. + It does not reorder sequences nor perform any load balancing. 
+ """ + raise NotImplementedError + - seqlens_gathered, offsets = self.get_global_seqlens(subsample_seqlens) +class EmptyPackingScheduler(BaseScheduler): + """ + This scheduler only packs sequences in their original order + and does not perform any load balancing. + """ + + def __init__( + self, + max_seqlen_per_dp_cp_rank, + cp_size, + dp_size, + microbatch_group_size_per_vp_stage, + dynamic_context_parallel: bool = False, + ): + super().__init__( + max_seqlen_per_dp_cp_rank, + cp_size, + dp_size, + microbatch_group_size_per_vp_stage, + dynamic_context_parallel=dynamic_context_parallel, + ) + + def check_require_sample_keys(self, batch: List[Dict]): + """ + Required per-(sub)sample fields expected by the default dynamic CP + - tokens[torch.Tensor]:1D tensor of input token ids for the (sub)sequence + (typically shape [padded_seq_len], dtype int64). + - labels[torch.Tensor]: 1D tensor of target token ids aligned with `tokens` for next-token + prediction (typically shape [padded_seq_len], dtype int64). + - loss_mask[torch.Tensor]: 1D tensor mask indicating which positions contribute to the + loss (typically shape [padded_seq_len], dtype float); must already reflect + any padding/EOD masking policy. + - position_ids[torch.Tensor]: 1D tensor of positional indices used by the model + (typically shape [padded_seq_len], dtype int64). + - original_seq_len[torch.Tensor]: Scalar int32 tensor length of the unpadded (effective) + sequence, used to build `cu_seqlens` for variable-length attention. + - padded_seq_len[torch.Tensor]: Scalar int32 tensor length after padding/truncation, + used to build `cu_seqlens_padded` and for max_seqlen computation. + - local_cp_size[torch.Tensor]: Scalar int32 tensor of the partner CP size, used to build + `local_cp_size` for Dynamic-CP. + - num_micro_batches_left[int]: number of microbatches left to be fetched. 
+ """ + required_keys = [ + "tokens", + "labels", + "loss_mask", + "position_ids", + "original_seq_len", + "padded_seq_len", + "local_cp_size", + "num_micro_batches_left", + ] - global_id_seqlens, global_ids_this_rank = self.get_global_id_seqlens( - subsample_seqlens.shape[0], offsets, seqlens_gathered + # - Each `next(data_iterator)` returns a microbatch worth of samples. + # - The returned `batch` is a Python `list` of per-sample dicts + # (i.e., `List[Dict[str, Tensor]]`). + # - `batch[0]["num_micro_batches_left"]` indicates how many additional microbatches + # should be fetched afterwards for the current step (so the caller will fetch + # `num_micro_batches_left` more times, in addition to the first fetch). + + for key in required_keys: + if key not in batch[0]: + return False + if "local_cp_size" in batch[0]: + assert ( + self.dynamic_context_parallel + ), "local_cp_size is only supported when using dynamic context parallel" + return True + + def get_groups_and_subsamples(self, sample_id_seqlens): + """ + This scheduler only packs sequences in their original order + and does not perform any load balancing. + """ + pass + + +class EmptyNoPackingScheduler(BaseScheduler): + """ + It does not pack sequences, it only returns the original batch. + """ + + def __init__( + self, + max_seqlen_per_dp_cp_rank, + cp_size, + dp_size, + microbatch_group_size_per_vp_stage, + dynamic_context_parallel: bool = False, + ): + super().__init__( + max_seqlen_per_dp_cp_rank, + cp_size, + dp_size, + microbatch_group_size_per_vp_stage, + dynamic_context_parallel=dynamic_context_parallel, + ) + + def check_require_sample_keys(self, batch: List[Dict]): + """ + Required per-(sub)sample fields expected by the default dynamic CP + - tokens[torch.Tensor]:1D tensor of input token ids for the (sub)sequence + (typically shape [padded_seq_len], dtype int64). 
+ - labels[torch.Tensor]: 1D tensor of target token ids aligned with `tokens` for next-token + prediction (typically shape [padded_seq_len], dtype int64). + - loss_mask[torch.Tensor]: 1D tensor mask indicating which positions contribute to the + loss (typically shape [padded_seq_len], dtype float); must already reflect any + padding/EOD masking policy. + - position_ids[torch.Tensor]: 1D tensor of positional indices used by the model + (typically shape [padded_seq_len], dtype int64). + - cu_seqlens[torch.Tensor]: 1D int32 tensor of cumulative (prefix-sum)*original sequence + lengths, shape [num_seqs + 1], with cu_seqlens[0] = 0. Used for variable-length + attention over unpadded token streams. + - cu_seqlens_padded[torch.Tensor]: 1D int32 tensor of cumulative (prefix-sum) padded + sequence lengths, shape [num_seqs + 1], with cu_seqlens_padded[0] = 0. + Used for packed layouts where each sequence occupies `padded_seq_len` tokens. + - max_seqlen[torch.Tensor]: Scalar int32 tensor of the maximum sequence length in + the microbatch (typically the max of padded lengths). + - local_cp_size[torch.Tensor]: Scalar int32 tensor of the local_cp_size CP size, used to + build `local_cp_size`. + - num_micro_batches_left[int]: number of microbatches left to be fetched. + - num_total_tokens_this_global_batch[float]: total number of tokens in the global batch. + Used for tflops calculation. + - sequence_square_sum_this_global_batch[float]: sum of the squares of the sequence lengths + in the global batch. Used for tflops calculation. + """ + required_keys = [ + "tokens", + "labels", + "loss_mask", + "position_ids", + "cu_seqlens", + "cu_seqlens_padded", + "max_seqlen", + "local_cp_size", + "num_micro_batches_left", + "num_total_tokens_this_global_batch", + "sequence_square_sum_this_global_batch", + ] + + # Each call returns the packed sequences for one microbatch; the data_iterator will fetch + # num_micro_batches_left more times. 
+ + for key in required_keys: + if key not in batch[0]: + return False + + if "local_cp_size" in batch[0]: + assert ( + self.dynamic_context_parallel + ), "local_cp_size is only supported when using dynamic context parallel" + + return True + + def get_groups_and_subsamples(self, sample_id_seqlens): + """ + This scheduler only packs sequences in their original order + and does not perform any load balancing. + """ + pass + + +class NaiveSequencePackingScheduler(BaseScheduler): + """ + This scheduler simply packs sequences in their original order + until reaching the max sequence length. + It does not reorder sequences nor perform any load balancing. + """ + + def __init__( + self, + max_seqlen_per_dp_cp_rank, + cp_size, + dp_size, + microbatch_group_size_per_vp_stage, + dynamic_context_parallel: bool = False, + ): + super().__init__( + max_seqlen_per_dp_cp_rank, + cp_size, + dp_size, + microbatch_group_size_per_vp_stage, + dynamic_context_parallel=dynamic_context_parallel, ) + self.max_seq_len_all_ranks = self.max_seqlen_per_dp_cp_rank * self.cp_size + + def check_require_sample_keys(self, batch: List[Dict]): + """ + Required per-(sub)sample fields expected by the default dynamic CP + - tokens[torch.Tensor]:1D tensor of input token ids for the (sub)sequence + (typically shape [padded_seq_len], dtype int64). + - labels[torch.Tensor]: 1D tensor of target token ids aligned with `tokens` for next-token + prediction (typically shape [padded_seq_len], dtype int64). + - loss_mask[torch.Tensor]: 1D tensor mask indicating which positions contribute to the + loss (typically shape [padded_seq_len], dtype float); must already reflect any + padding/EOD masking policy. + - position_ids[torch.Tensor]: 1D tensor of positional indices used by the model + (typically shape [padded_seq_len], dtype int64). + - original_seq_len[torch.Tensor]: Scalar int32 tensor length of the unpadded (effective) + sequence, used to build `cu_seqlens` for variable-length attention. 
+ - padded_seq_len[torch.Tensor]: Scalar int32 tensor length after padding/truncation, + used to build `cu_seqlens_padded` and for max_seqlen computation. + """ + required_keys = [ + "tokens", + "labels", + "loss_mask", + "position_ids", + "original_seq_len", + "padded_seq_len", + ] + # data_iterator returns all samples at once; + # we only fetch it once, rather than iterating num_micro_batches times. + for key in required_keys: + if key not in batch[0]: + return False + return True + + def get_groups_and_subsamples(self, sample_id_seqlens): + """ + This scheduler simply packs sequences in their original order + until reaching the max sequence length. + It does not reorder sequences nor perform any load balancing. + """ + groups = [] + sample_id_groups = [] + packed_id_groups = [] + sum_seqlen = 0 + single_microbatch = [] + + for i in range(len(sample_id_seqlens)): + if sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks: + single_microbatch.append(i) + sum_seqlen += sample_id_seqlens[i][1] + else: + packed_id_groups.append(single_microbatch) + single_microbatch = [i] + sum_seqlen = sample_id_seqlens[i][1] + if len(single_microbatch) > 0: + packed_id_groups.append(single_microbatch) - groups, sample_id_groups = self.cp_balancing_scheduler.get_groups_and_subsamples( - global_id_seqlens, self.config + # we want the number of packed sequences to be multiple of dp_size + # so we move few samples from previous microbatch + # to the end of the microbatches if needed + num_packed_sequence = len(packed_id_groups) + + # when enabling vpp, we want the number of packed sequences to be + # multiple of dp_size * microbatch_group_size_per_vp_stage + multiple = self.dp_size * ( + self.microbatch_group_size_per_vp_stage + if self.microbatch_group_size_per_vp_stage is not None + else 1 ) + if num_packed_sequence % multiple != 0: + remainder = num_packed_sequence % multiple + num_to_move = multiple - remainder + i = num_packed_sequence - 1 + while num_to_move > 0: + 
assert i > 0, "Not enough samples to move" + if len(packed_id_groups[i]) > 1: + seq_id = packed_id_groups[i].pop() + packed_id_groups.append([seq_id]) + num_to_move -= 1 + else: + i -= 1 + + num_micro_batches = int(len(packed_id_groups) / self.dp_size) + for i in range(num_micro_batches): + sample_id_groups.append([]) + for j in range(self.cp_size * self.dp_size): + seq_id = int(i * self.dp_size + j / self.cp_size) + sample_id_groups[i].append(packed_id_groups[seq_id]) + return sample_id_groups + + +class DefaultDynamicCPscheduler(BaseScheduler): + """ + This class provides the functionality to form groups of sub-samples + such that all DPxCP ranks have a roughly balanced workload in the group. + """ - batch = self.unpack_batch(batch) - samples_this_rank_with_id = self.reroute_samples_to_hdp_ranks( - batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets + def __init__( + self, + max_seqlen_per_dp_cp_rank, + cp_size, + dp_size, + microbatch_group_size_per_vp_stage, + dynamic_context_parallel: bool = False, + ): + super().__init__( + max_seqlen_per_dp_cp_rank, + cp_size, + dp_size, + microbatch_group_size_per_vp_stage, + dynamic_context_parallel=dynamic_context_parallel, ) - return samples_this_rank_with_id, sample_id_groups + self.max_seq_len_per_rank = self.max_seqlen_per_dp_cp_rank + self.total_hdp_gpus = self.dp_size * self.cp_size + + def check_require_sample_keys(self, batch: List[Dict]): + """ + Required per-(sub)sample fields expected by the default dynamic CP + - tokens[torch.Tensor]:1D tensor of input token ids for the (sub)sequence + (typically shape [padded_seq_len], dtype int64). + - labels[torch.Tensor]: 1D tensor of target token ids aligned with `tokens` for next-token + prediction (typically shape [padded_seq_len], dtype int64). + - loss_mask[torch.Tensor]: 1D tensor mask indicating which positions contribute to the + loss (typically shape [padded_seq_len], dtype float); must already reflect any + padding/EOD masking policy. 
+ - position_ids[torch.Tensor]: 1D tensor of positional indices used by the model + (typically shape [padded_seq_len], dtype int64). + - original_seq_len[torch.Tensor]: Scalar int32 tensor length of the unpadded (effective) + sequence, used to build `cu_seqlens` for variable-length attention. + - padded_seq_len[torch.Tensor]: Scalar int32 tensor length after padding/truncation, + used to build `cu_seqlens_padded` and for max_seqlen computation. + """ + required_keys = [ + "tokens", + "labels", + "loss_mask", + "position_ids", + "original_seq_len", + "padded_seq_len", + ] + + # data_iterator returns all samples at once; + # we only fetch it once, rather than iterating num_micro_batches times. + for key in required_keys: + if key not in batch[0]: + #debugmtl + print(f"key {key} not in batch[0]: {batch[0]}") + return False + return True + + @lru_cache(maxsize=128) + def get_total_workload(self, seq_length: int, cp_size: Optional[int] = None): + """ + seq_length: sequence length of a sub-sample + cp_size: total number of CP ranks working on this sub-sample + + Note: + This function is used to estimate the relative workload intensity + of a sub-sample. This is not meant to be an accurate flops calculator. + + Returns: workload of a sub-sample + """ + if cp_size is None: + cp_size = self.gpus_needed(seq_length) + return (seq_length * seq_length) / cp_size + + @lru_cache(maxsize=128) + def gpus_needed(self, seq_len: int) -> int: + """ + Calculates the number of GPUs needed for a given sequence length + and max sequence length per CP rank. + This is used to determine the CP size of a sub-sample. + + The number is rounded up to the next power of 2 to match the available + dynamic context parallel process group sizes. 
+ """ + return max(1, 2 ** ceil(log2((seq_len / self.max_seq_len_per_rank)))) + + def make_buckets_equal( + self, + sample_seqlens: List[Tuple[int, int]], # List of (sample_id, sequence_length) tuples + compute_estimator: Callable[[int], float], + ) -> List[Deque]: + """ + Makes as many buckets as unique CP sizes needed. + This keeps sample IDs tethered to their sequence lengths throughout the bucketing process. + """ + # Extract just the sequence lengths for determining k + seqlens = [seq_len for _, seq_len in sample_seqlens] + + # Determine k based on unique GPU categories needed + k = len({self.gpus_needed(L) for L in seqlens}) + + # Create a work target for each bucket + # This is the total work divided by the number of buckets + work = [] + for _, s in sample_seqlens: + cp_size = self.gpus_needed(s) + work.append(compute_estimator(s, cp_size)) + total_work = sum(work) + target = total_work / k + buckets, cur, cur_work = [], [], 0.0 + remaining_work = total_work + remaining_k = k + + for i, (sample_id, seq_len) in enumerate(sample_seqlens): + work = compute_estimator(seq_len) + projected = cur_work + work + + # Check if we should close this bucket + if cur and ( + projected > target * 1.1 # Too much work + or len(sample_seqlens) - i <= remaining_k - len(buckets) + ): # Need to save sequences for remaining buckets + buckets.append(deque(cur)) + cur, cur_work = [], 0.0 + remaining_work -= sum(compute_estimator(seq_len) for _, seq_len in cur) + remaining_k -= 1 + + cur.append((sample_id, seq_len)) + cur_work += work + + if cur: + buckets.append(deque(cur)) + + return buckets + + def next_hdp_group( + self, + sample_seqlens: List[Tuple[int, int]], # List of (sample_id, sequence_length) tuples + compute_estimator: Callable[[int], float], + total_gpus: int, + delta: float = 0.05, # balance slack (e.g. 
5 %) + strategy: str = "dp", # "dp" or "pp" + eps_bucket: float = 0.10, # ε target for bucket balance + ) -> Tuple[List[List[int]], List[Tuple[int, int]], List[float], List[List[int]]]: + """ + Given a list of (sample_id, sequence_length) tuples, this function aims to assign + sequences in a group such that all GPUs in the DPxCP group have a roughly balanced + workload. Once each group is roughly balanced, we exit and return the + group and the leftover sequences. + + The function performs the following passes in order to form a balanced microbatch: + 1. We create buckets of sequences that are roughly balanced. + We try to create as many buckets as possible CP sizes. + 2. Given a bucket has sequences available, we assign the sample + a. To a new set of GPUs if there are enough free GPUs. + b. To an existing set of GPUs with the lowest load. + 3. We check if the group is balanced whenever we need to move onto a new CP size + in the same set of GPUs. + 4. We trim the group if removing the last added sequence helps improve balance. + 5. If we run out of sequences to assign and there are empty GPUs, + we redistribute work to empty GPUs by recursively increasing the CP size of a + sample until no empty GPUs are left. + + #TODO: Add clarification on when we check for balance. What does prev_needed do? + + Returns (micro_batches, leftover_sample_seqlens, exec_times, sample_ids_per_gpu). 
+ """ + if not sample_seqlens: + return ( + [[] for _ in range(total_gpus)], + [], + [0.0 for _ in range(total_gpus)], + [[] for _ in range(total_gpus)], + ) + + # Get buckets of sequences with balanced work + buckets = self.make_buckets_equal(sample_seqlens, compute_estimator) + + # Initialize tracking structures + micro_batches = [[] for _ in range(total_gpus)] + exec_times = [0.0 for _ in range(total_gpus)] + sample_ids_per_gpu = [[] for _ in range(total_gpus)] + # gid : seq_len + packing_sequence_len = {} + + gpu_group_id = [None] * total_gpus + group_members = {} + group_size = {} + next_gid = 0 + + pp_cursor = 0 + prev_needed = None + check_balance = False + + while buckets: + # ---- Step 1 – pick the next sequence we COULD place ------------------ + sample_seq_tuple = bucket_idx = None + needed = None + + scan_order = ( + range(len(buckets)) + if strategy == "dp" + else [(pp_cursor + i) % len(buckets) for i in range(len(buckets))] + ) + + for idx in scan_order: + if not buckets[idx]: + continue + cand_tuple = buckets[idx][0] # This is now (sample_id, seq_len) + cand_seq_len = cand_tuple[1] + needed = self.gpus_needed(cand_seq_len) + + # (a) Do we have an *existing* group of size `needed`? + candidate_gids = [gid for gid, sz in group_size.items() if sz == needed] + + # (b) Or enough completely free GPUs to start a new group? + free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None] + if candidate_gids or len(free_ranks) >= needed: + sample_seq_tuple, bucket_idx = cand_tuple, idx + break + + # No place to put any remaining sequence – finish this micro‑batch + if sample_seq_tuple is None: + break + + # TODO[pmannan]: PP not yet supported. Add PP scheduling. 
+ if strategy == "pp": + pp_cursor = (bucket_idx + 1) % len(buckets) + + sample_id, seq_len = sample_seq_tuple + needed = self.gpus_needed(seq_len) + if prev_needed is None: + prev_needed = needed + + # (a) Existing groups of exactly this size + candidate_gids = [ + gid + for gid, sz in group_size.items() + if sz == needed + and packing_sequence_len[gid] + seq_len / needed <= self.max_seq_len_per_rank + ] + if candidate_gids: + best_gid, best_load = min( + ( + (gid, max(exec_times[r] for r in group_members[gid])) + for gid in candidate_gids + ), + key=lambda t: t[1], + ) + else: + best_gid, best_load = None, float("inf") + + # (b) Hypothetical **new** group from completely free GPUs + free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None] + if len(free_ranks) >= needed: + free_sorted = sorted(free_ranks, key=lambda r: exec_times[r]) + new_members = free_sorted[:needed] + new_load = exec_times[new_members[-1]] + + if new_load < best_load: + best_gid = None + chosen_members = new_members + else: + chosen_members = group_members[best_gid] + else: + if best_gid is None: + break + chosen_members = group_members[best_gid] + + # ---- Step 2 – if we decided to create a fresh group ---------------- + if best_gid is None: + best_gid = next_gid + next_gid += 1 + group_members[best_gid] = chosen_members + group_size[best_gid] = needed + for r in chosen_members: + gpu_group_id[r] = best_gid + + # ---- Step 3 – assign the sequence to every member of that group ------ + per_gpu_cost = compute_estimator(seq_len) + + packing_sequence_len[best_gid] = ( + packing_sequence_len.get(best_gid, 0) + seq_len / needed + ) + for r in chosen_members: + micro_batches[r].append(seq_len) + exec_times[r] += per_gpu_cost + sample_ids_per_gpu[r].append(sample_id) + + # Remove the sequence definitively from its bucket + buckets[bucket_idx].popleft() + + # ---- Step 4 – tidy, balance‑check, maybe early‑exit ------------------ + while buckets and not buckets[0]: + buckets.pop(0) + 
pp_cursor %= max(1, len(buckets)) + + # TODO: Removing this helps reduce the number of groups when we have + # lots of samples with same CP size. + # But because we don't exit as soon as we get balanced, + # even if there is one group available that can take the next sample, + # we will keep adding samples to the same group. + # trim_overload() does not help because it only checks if removing the + # last added sample helps. + # We cannot check after adding every sample because there will always be imbalance + # if we don't wait for future scheduling. + + # IMPORTANT: So we need a solution here + if needed < prev_needed: + # When we get into a lower CP size in the same group, + # we can start checking for balance. There is still a gotcha here. + # Let's say we have a group of 3 GPU 0-2, then we move onto group of 2. + # We keep assigning group of 2 as we do in descending order but GPU 7/15 + # never sees a microbatch assigned to it + # until we run out of samples with CP2. + # This means we are never balanced as min(exec_times) will always be 0. + # We need a smart way of identifying that we have run out of big samples + # and if we are having to assign work to a GPU already working, + # is it because there are empty GPUs? + # Would assigning work to empty GPUs first by moving onto next CP bucket help? + # But we need to remember to come back to this CP size bucket and then + # check for balance. Maybe the scheduling algorithm should look at empty + # GPUs and find work rather than going sequence by sequence. 
+ check_balance = True + + if ( + check_balance + and buckets + and max(exec_times) - min(exec_times) <= delta * max(exec_times) + ): + break + + # Gather leftovers (flatten remaining buckets, preserve order) + leftovers = [] + for b in buckets: + for sample_seq_tuple in b: + leftovers.append(sample_seq_tuple) + + # --------------------------------------------------------------------------- + def trim_overload(): + """ + Iteratively pop the most-recent sequence from the *most-loaded group* + whenever doing so reduces the global slack. + """ + while True: + cur_max = max(exec_times) + cur_min = min(exec_times) + cur_slack = cur_max - cur_min + if cur_slack <= delta * cur_max: + # Slack is already within limit. + break + if cur_min == 0: + # There are empty GPUs that will be + # handled in the next step. + break + + max_r = exec_times.index(cur_max) + gid = gpu_group_id[max_r] + members = group_members[gid] + + if not micro_batches[max_r] or len(micro_batches[max_r]) <= 1: + break + + seq = micro_batches[max_r][-1] + need = group_size[gid] + per_gpu_cost = compute_estimator(seq) + + proj_times = exec_times[:] + for r in members: + proj_times[r] -= per_gpu_cost + + proj_slack = max(proj_times) - min(proj_times) + + # Check if trimming the workload helps imbalance + if proj_slack < cur_slack: + sample_id_to_remove = sample_ids_per_gpu[max_r][-1] + for r in members: + micro_batches[r].pop() + exec_times[r] -= per_gpu_cost + sample_ids_per_gpu[r].pop() + leftovers.append((sample_id_to_remove, seq)) + else: + break + + # TODO(tailaim): uncomment this to support different ranks have different num_microbatches + # trim_overload() + + # Track samples in this group before redistribution to empty GPUs + total_work_before = sum(len(mb) for mb in micro_batches) + + # Check for empty GPUs and redistribute work + def fill_empty_gpus( + micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size + ): + """ + Recursively check for empty GPUs and redistribute work by 
increasing + the number of GPUs sharing samples. This ensures all GPUs have work. + GPUs must be allocated consecutively so we may need to push existing + work to other ranks in order to expand samples. + """ + # Find empty GPUs + empty_gpus = [i for i in range(total_gpus) if not micro_batches[i]] + if not empty_gpus: + return ( + micro_batches, + exec_times, + sample_ids_per_gpu, + group_members, + group_size, + ) # No empty GPUs, we're done + + # Find the smallest group size that exists + existing_group_sizes = set(group_size.values()) + assert ( + existing_group_sizes + ), "There should be at least one group existing, cannot reditribute, " + "try to increase 'max-seqlen-per-dp-cp-rank'." + + min_group_size = min(existing_group_sizes) + # We have Dynamic DPxCP groups for every power of 2 of GPUs or the entire DPxCP group. + next_power = min(min_group_size * 2, total_gpus) + + # Find the first group of min_group_size that can be expanded + expandable_gid = None + expandable_members = None + expandable_new_gpus = None + + for gid, size in group_size.items(): + if size == min_group_size: + members = group_members[gid] + needed_count = next_power - min_group_size + group_start_gpu = members[0] + group_end_gpu = members[-1] + empty_gpu = [idx for idx, work in enumerate(micro_batches) if not work][0] + assert not all( + work for work in micro_batches[empty_gpu : empty_gpu + needed_count] + ), f"Empty GPUs were detected but not enough to expand." 
+ work_to_push = micro_batches[ + group_end_gpu + 1 : empty_gpu + ] # This is work of all other subsequent sub-samples + exec_times_to_push = exec_times[group_end_gpu + 1 : empty_gpu] + sample_ids_to_push = sample_ids_per_gpu[group_end_gpu + 1 : empty_gpu] + + new_micro_batches = [[]] * len(micro_batches) + new_exec_times = [0.0] * len(exec_times) + new_sample_ids_per_gpu = [[]] * len(sample_ids_per_gpu) + + # No change in work until the group selected for expansion + for i in range(group_start_gpu): + new_micro_batches[i] = micro_batches[i] + new_exec_times[i] = exec_times[i] + new_sample_ids_per_gpu[i] = sample_ids_per_gpu[i] + + # The work is distributed across the expanded group + for i in range(group_start_gpu, group_end_gpu + needed_count + 1): + new_micro_batches[i] = micro_batches[group_end_gpu] + new_exec_times[i] = self.get_total_workload( + micro_batches[group_end_gpu][0], next_power + ) + new_sample_ids_per_gpu[i] = sample_ids_per_gpu[group_end_gpu] + + # Any assigned work on expanded GPUs is pushed + for i, work in enumerate(work_to_push): + new_micro_batches[group_end_gpu + needed_count + 1 + i] = work + new_exec_times[group_end_gpu + needed_count + 1 + i] = exec_times_to_push[i] + new_sample_ids_per_gpu[group_end_gpu + needed_count + 1 + i] = ( + sample_ids_to_push[i] + ) + + group_size[gid] = next_power + group_members[gid] = list(range(members[0], members[-1] + needed_count + 1)) + for pushed_gid in group_size.keys(): + if pushed_gid > gid: + group_members[pushed_gid] = [ + x + needed_count for x in group_members[pushed_gid] + ] + + return ( + new_micro_batches, + new_exec_times, + new_sample_ids_per_gpu, + group_members, + group_size, + ) + + empty_gpus = any([not micro_batches[i] for i in range(total_gpus)]) + while empty_gpus: + micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size = ( + fill_empty_gpus( + micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size + ) + ) + empty_gpus = any([not micro_batches[i] 
for i in range(total_gpus)]) + + # Assert that no sample has been completely removed + total_work_after = sum(len(mb) for mb in micro_batches) + assert ( + total_work_after >= total_work_before + ), f"Samples were removed: {total_work_before} -> {total_work_after}" + + return micro_batches, leftovers, exec_times, sample_ids_per_gpu + + def align_sample_id_groups(self, sample_id_groups): + """ + Align len(sample_id_groups) to microbatch_group_size_per_vp_stage (K) when VPP is enabled. + i.e. if len(sample_id_groups) % K != 0, we need to add extra microbatches. + """ + multiple = int(self.microbatch_group_size_per_vp_stage) + remainder = (-len(sample_id_groups)) % multiple + i = len(sample_id_groups) - 1 + + def split_group(sample_id_group): + total_hdp_ranks = len(sample_id_group) + cu_ranks = [0] + prev_cp_size = 0 + + while cu_ranks[-1] != total_hdp_ranks: + start_rank = cu_ranks[-1] + sid0 = sample_id_group[start_rank][0] + # Count contiguous ranks that share sid0 (CP group assumed contiguous). + cp_size = 0 + for r in range(start_rank, total_hdp_ranks): + if sid0 in sample_id_group[r]: + cp_size += 1 + else: + break + assert ( + prev_cp_size == 0 or cp_size <= prev_cp_size + ), f"split_group: CP size is not decreasing: prev={prev_cp_size}, cur={cp_size}" + cu_ranks.append(start_rank + cp_size) + prev_cp_size = cp_size + if len(cu_ranks) == 2: + # can't split anymore + return None, None + + k = 0 + while cu_ranks[k] < total_hdp_ranks // 2: + k += 1 + + # Keep original rank positions; zero out the other half, then expand CP to fill empties. 
+ old_mb = sample_id_group[: cu_ranks[k]] + [ + [] for _ in range(total_hdp_ranks - cu_ranks[k]) + ] + new_mb = sample_id_group[cu_ranks[k] :] + [[] for _ in range(cu_ranks[k])] + old_mb = fill_empty_by_expanding_cp(old_mb) + new_mb = fill_empty_by_expanding_cp(new_mb) + return new_mb, old_mb + + def fill_empty_by_expanding_cp(sample_id_group): + def fill_empty(sample_id_group): + empty_size = sum(1 for x in sample_id_group if len(x) == 0) + i = len(sample_id_group) - 1 - empty_size + prev_cp_size = 0 + while i >= 0: + sid0 = sample_id_group[i][0] + cp_size = 0 + while sid0 in sample_id_group[i] and i >= 0: + cp_size += 1 + i -= 1 + if cp_size > prev_cp_size and prev_cp_size != 0: + # double cp_size of this group + start_idx = i + 1 + cp_size + end_idx = ( + -empty_size + prev_cp_size if -empty_size + prev_cp_size < 0 else None + ) + sample_id_group[start_idx + 2 * prev_cp_size : end_idx] = sample_id_group[ + start_idx + prev_cp_size : -empty_size + ] + sample_id_group[start_idx + prev_cp_size : start_idx + 2 * prev_cp_size] = ( + sample_id_group[start_idx : start_idx + prev_cp_size] + ) + break + elif cp_size <= empty_size and i == -1: + end_idx = -empty_size + cp_size if -empty_size + cp_size < 0 else None + sample_id_group[2 * cp_size : end_idx] = sample_id_group[ + cp_size:-empty_size + ] + sample_id_group[cp_size : 2 * cp_size] = sample_id_group[0:cp_size] + break + prev_cp_size = cp_size + return sample_id_group + + while len(sample_id_group[-1]) == 0: + sample_id_group = fill_empty(sample_id_group) + return sample_id_group + + attempts_since_split = 0 + while remainder > 0: + if i < 0: + if attempts_since_split >= len(sample_id_groups): + assert ( + False + ), f'align_sample_id_groups: no tail microbatch has enough ids to split' + i = len(sample_id_groups) - 1 + group1, group2 = split_group(sample_id_groups[i]) + if group1 is not None and group2 is not None: + sample_id_groups[i] = group1 + sample_id_groups.append(group2) + remainder -= 1 + 
attempts_since_split = 0 + else: + attempts_since_split += 1 + i -= 1 + + return sample_id_groups + + def get_groups_and_subsamples(self, sample_id_seqlens): + """ + This function recursively forms groups of sub-samples such that all DPxCP ranks + have a roughly balanced workload in the group. + """ + groups = [] + sample_id_groups = [] + # We assign a sample_id to each sub-sample in order to track assignment to each GPU. + sample_id_seqlens = sorted(sample_id_seqlens, key=lambda x: x[1], reverse=True) + while sample_id_seqlens: + mb, sample_id_seqlens, exec_times, sample_ids = self.next_hdp_group( + sample_id_seqlens, self.get_total_workload, self.total_hdp_gpus + ) + groups.append(mb) + sample_id_groups.append(sample_ids) + + if ( + self.microbatch_group_size_per_vp_stage is not None + and self.microbatch_group_size_per_vp_stage > 1 + ): + sample_id_groups = self.align_sample_id_groups(sample_id_groups) + + return sample_id_groups + + +def get_batch_on_this_rank_for_sequence_packing( + data_iterator, + mtp_on_this_rank: bool = False, + vp_stage: Optional[int] = None, + dynamic_context_parallel: bool = False, +): + """ + Get a batch of data for sequence packing. + Args: + data_iterator (Iterator): The data iterator to get the batch from. + mtp_on_this_rank (bool): Whether to use multi-token prediction. + vp_stage (Optional[int]): The stage of the pipeline. + dynamic_context_parallel (bool): Whether to use dynamic context parallel. 
+    Returns:
+        tuple of (tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params)
+    """
+
+    def _broadcast_to_tp_group(item):
+        if item is not None:
+            torch.distributed.broadcast(
+                item,
+                parallel_state.get_tensor_model_parallel_src_rank(),
+                group=parallel_state.get_tensor_model_parallel_group(),
+            )
+
+    is_tp_rank_0 = parallel_state.get_tensor_model_parallel_rank() == 0
+    is_first_stage = parallel_state.is_pipeline_first_stage(
+        ignore_virtual=vp_stage is None, vp_stage=vp_stage
+    )
+    is_last_stage = parallel_state.is_pipeline_last_stage(
+        ignore_virtual=vp_stage is None, vp_stage=vp_stage
+    )
+    is_first_or_last_stage = is_first_stage or is_last_stage
+    dev = torch.cuda.current_device()
+
+    # data_iterator should return a batch including the following keys.
+    batch_keys = [
+        'cu_seqlens',
+        'cu_seqlens_padded',
+        'max_seqlen',
+    ]
+    if dynamic_context_parallel:
+        batch_keys.append('local_cp_size')
+    if is_first_stage:
+        batch_keys.append('tokens')
+        batch_keys.append('position_ids')
+    if is_last_stage:
+        batch_keys.append('labels')
+        batch_keys.append('loss_mask')
+
+    # Get a batch from data_iterator or create an empty batch.
+    if is_tp_rank_0:
+        assert data_iterator is not None
+        batch = next(data_iterator)
+        for key in batch_keys:
+            assert key in batch, f"{key} is missing in current batch"
+    else:
+        assert data_iterator is None, "Non TP 0 rank should not have data_iterator"
+        batch = {}
+
+    # Partition tokens, position_ids, labels, loss_mask for context parallel, currently only
+    # TP rank 0 and the first/last PP stage rank have this data.
+    if is_tp_rank_0 and is_first_or_last_stage:
+        # Get the proper cp_size and cp_rank based on whether dynamic context parallel is enabled.
+ if dynamic_context_parallel: + cp_size = batch['local_cp_size'] + if type(cp_size) == torch.Tensor: + cp_size = cp_size.item() + cp_rank = torch.distributed.get_rank( + group=parallel_state.get_dynamic_data_context_parallel_groups(group_size=cp_size) + ) + else: + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = parallel_state.get_context_parallel_rank() + # If cp_size == 1, no need to do further processing. + if cp_size > 1: + assert tex is not None and is_te_min_version("1.10.0"), ( + "Please update Transformer Engine to >= 1.10 to use " + "Context Parallel with THD format data" + ) + total_tokens = batch['tokens'].size(0) + # Transformer Engine has a bug of cu_seqlens, we must treat cu_seqlens_padded as + # cu_seqlens to get the correct result. + # TODO: Revert this workaround once TE fixes the issue. + cu_seqlens = batch["cu_seqlens_padded"] + index = tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) + for key in ['tokens', 'position_ids', 'labels', 'loss_mask']: + batch[key] = batch[key].index_select(0, index) + + # Broadcast cu_seqlens_size because we need it to create placeholder for cu_seqlens and + # cu_seqlens_padded for non TP 0 ranks. + if is_tp_rank_0: + cu_seqlen_size = torch.tensor(batch['cu_seqlens'].size(0), dtype=torch.int32, device=dev) + else: + cu_seqlen_size = torch.empty(1, dtype=torch.int32, device=dev) + _broadcast_to_tp_group(cu_seqlen_size) + cu_seqlen_size = cu_seqlen_size.item() + + # Broadcast total_tokens because we need it to create placeholder for tokens, position_ids, + # labels, loss_mask for non TP 0 ranks. Only first or last stage need this. + if is_first_or_last_stage: + if is_tp_rank_0: + total_tokens = torch.tensor(batch['tokens'].size(0), dtype=torch.int32, device=dev) + else: + total_tokens = torch.empty(1, dtype=torch.int32, device=dev) + _broadcast_to_tp_group(total_tokens) + total_tokens = total_tokens.item() + + # Step1: Prepare "tokens", "position_ids" on all ranks. 
+ if is_first_stage or mtp_on_this_rank: + if is_tp_rank_0: + assert batch['tokens'].dtype == torch.int64 + assert batch['position_ids'].dtype == torch.int64 + batch['tokens'] = batch['tokens'].view(1, total_tokens) + batch['position_ids'] = batch['position_ids'].view(1, total_tokens) + else: + batch['tokens'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev) + batch['position_ids'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev) + else: + # Non first stage rank doesn't need tokens and position_ids. + batch['tokens'] = None + batch['position_ids'] = None + + # Step2: Prepare "labels", "loss_mask" on all ranks. + if is_last_stage: + if is_tp_rank_0: + assert batch['labels'].dtype == torch.int64 + assert batch['loss_mask'].dtype == torch.float32 + batch['labels'] = batch['labels'].view(1, total_tokens) + batch['loss_mask'] = batch['loss_mask'].view(1, total_tokens) + else: + batch['labels'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev) + batch['loss_mask'] = torch.empty([1, total_tokens], dtype=torch.float32, device=dev) + else: + # Non last stage rank doesn't need labels and loss_mask. + batch['labels'] = None + batch['loss_mask'] = None + + # Step3: Prepare "cu_seqlens", "cu_seqlens_padded", "max_seqlen" on all ranks. 
+ if is_tp_rank_0: + assert batch['cu_seqlens'].dtype == torch.int32 + assert batch['cu_seqlens_padded'].dtype == torch.int32 + assert batch['cu_seqlens'].dim() == 1 + assert batch['cu_seqlens_padded'].dim() == 1 + if type(batch['max_seqlen']) == int: + batch['max_seqlen'] = torch.tensor(batch['max_seqlen'], dtype=torch.int32, device=dev) + else: + assert batch['max_seqlen'].dtype == torch.int32 + assert batch['max_seqlen'].numel() == 1 + else: + batch['cu_seqlens'] = torch.empty([cu_seqlen_size], dtype=torch.int32, device=dev) + batch['cu_seqlens_padded'] = torch.empty([cu_seqlen_size], dtype=torch.int32, device=dev) + batch['max_seqlen'] = torch.empty(1, dtype=torch.int32, device=dev) + + # Step4(optional): Prepare "local_cp_size" if dynamic context parallel is enabled. + if dynamic_context_parallel: + if is_tp_rank_0: + if type(batch['local_cp_size']) == int: + batch['local_cp_size'] = torch.tensor( + batch['local_cp_size'], dtype=torch.int32, device=dev + ) + else: + assert batch['local_cp_size'].dtype == torch.int32 + assert batch['local_cp_size'].numel() == 1 + else: + batch['local_cp_size'] = torch.empty(1, dtype=torch.int32, device=dev) + + # Broadcast batch inside TP group. + _broadcast_to_tp_group(batch['tokens']) + _broadcast_to_tp_group(batch['position_ids']) + _broadcast_to_tp_group(batch['labels']) + _broadcast_to_tp_group(batch['loss_mask']) + _broadcast_to_tp_group(batch['cu_seqlens']) + _broadcast_to_tp_group(batch['cu_seqlens_padded']) + _broadcast_to_tp_group(batch['max_seqlen']) + if dynamic_context_parallel: + _broadcast_to_tp_group(batch['local_cp_size']) + + # Extract the data from batch after broadcasting. + tokens = batch['tokens'] + position_ids = batch['position_ids'] + labels = batch['labels'] + loss_mask = batch['loss_mask'] + cu_seqlens = batch['cu_seqlens'] + cu_seqlens_padded = batch['cu_seqlens_padded'] + max_seqlen = batch['max_seqlen'].item() + + # Set the proper cp_group and local_cp_size when dynamic context parallel is enabled. 
+ if dynamic_context_parallel: + local_cp_size = batch['local_cp_size'].item() + cp_group = parallel_state.get_dynamic_data_context_parallel_groups(group_size=local_cp_size) + else: + local_cp_size = None + cp_group = None + + # Transformer Engine has a bug of cu_seqlens, we must treat cu_seqlens_padded as cu_seqlens to + # get the correct result. + # TODO: Revert this workaround once TE fixes the issue. + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + cu_seqlens_kv=cu_seqlens_padded, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + local_cp_size=local_cp_size, + cp_group=cp_group, + ) + + # "attention_mask" is not valid for sequence packing, so set it to None. + return tokens, labels, loss_mask, None, position_ids, packed_seq_params diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index a2d39a6d688..a8df0fc03d2 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -60,13 +60,19 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig): Set to 0 if sequence parallel is not enabled regardless of TP size. """ - hybrid_context_parallel: bool = False - """Option to enable hybrid context parallelism. When setting this to True, + dynamic_context_parallel: bool = False + """Option to enable dynamic context parallelism. When setting this to True, each sample should be divisible by the data parallel size * context parallel size * 2. If sequence parallel is enabled, it should be divisible by the data parallel size * context parallel size * sequence parallel size * 2. 
""" + sft_mock_dataset_config_json: Optional[str] = None + """This config provides the necessary information for the mock dataset.""" + + sequence_packing: bool = False + """Option to enable sequence packing for training.""" + def __post_init__(self) -> None: """Do asserts and set fields post init""" super().__post_init__() diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index d823e42b0bc..1b5c0e5d5c8 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -1229,6 +1229,13 @@ def __init__( else: extra_kwargs["cp_comm_type"] = cp_comm_type + # we need to create a single stream for cp=1 and dynamic cp enabled. + if ( + self.config.dynamic_context_parallel + and getattr(TEDotProductAttention, "cp_stream") is None + ): + TEDotProductAttention.cp_stream = torch.cuda.Stream() + if self.config.deterministic_mode: if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0: raise RuntimeError( @@ -1322,21 +1329,17 @@ def forward( """Forward.""" if packed_seq_params is not None: # If Dynamic CP group is provided, update TE DPA CP group - if packed_seq_params.cp_group is not None: - self.cp_group = packed_seq_params.cp_group - super().set_context_parallel_group( - self.cp_group, - torch.distributed.get_process_group_ranks(self.cp_group), - TEDotProductAttention.cp_stream, - self.cp_comm_type, - ) - # If cp_group is None but local_cp_size is provided, - # Indicates to turn off CP dynamically - elif packed_seq_params.local_cp_size is not None: - assert ( - packed_seq_params.local_cp_size == 1 - ), "local_cp_size must be == 1 if provided without cp_group" - super().set_context_parallel_group(None, None, None, self.cp_comm_type) + if packed_seq_params.local_cp_size is not None: + if packed_seq_params.local_cp_size == 1: + super().set_context_parallel_group(None, None, None, self.cp_comm_type) + else: + self.cp_group = packed_seq_params.cp_group + 
super().set_context_parallel_group( + self.cp_group, + torch.distributed.get_process_group_ranks(self.cp_group), + TEDotProductAttention.cp_stream, + self.cp_comm_type, + ) self.kept_packed_seq_params.discard("cp_group") self.kept_packed_seq_params.discard("local_cp_size") diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 4452bdf360b..e2b35c04138 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -6,7 +6,14 @@ import torch -from megatron.core.utils import experimental_api +from megatron.core.utils import experimental_api, get_te_version, is_te_min_version + +try: + from packaging.version import Version as PkgVersion + + HAVE_PACKAGING = True +except ImportError: + HAVE_PACKAGING = False @dataclass @@ -62,14 +69,32 @@ class ModelParallelConfig: can handle without overflowing the memory. Typically, a good starting point is to set this to maximum sequence length / context parallel size. This is used to calculate the number and length of sub-samples assigned to - each rank when using hybrid_context_parallel. + each rank when using sequence_packing. """ - hybrid_context_parallel: bool = False + dynamic_context_parallel: bool = False """ - If true, enables hybrid context parallel. This is used to balance the workload of + If true, enables dynamic context parallel. This is used to balance the workload of each CP rank when we use packed samples with variable sequence lengths. - Please set max_seqlen_per_dp_cp_rank when using hybrid_context_parallel. + Please set max_seqlen_per_dp_cp_rank when using dynamic_context_parallel. + When enabling dynamic_context_parallel, sequence_packing must be true. + """ + + sequence_packing_scheduler: Optional[str] = None + """ + Scheduler for sequence packing and dynamic context parallel. 
+ naive_sequence_packing: default naive sequence packing scheduler(just THD, no Dynamic-CP, this + is just for comparison with default dynamic-cp scheduler, not recommended for production) + default_dynamic_cp: default dynamic-cp scheduler for dynamic context parallel provided by MCore. + empty_scheduler_with_packing: scheduling is already handled by the data sampler, + this scheduler only performs packing. + empty_scheduler_no_packing: scheduling and packing are already handled by the data sampler, + this scheduler only returns the batch. + """ + + sequence_packing: bool = False + """ + If true, enables sft sequence packing. """ expert_model_parallel_size: int = 1 @@ -438,3 +463,23 @@ def __post_init__(self): "Pipeline parallel communication overlapping in warmup and flush is only " "compatible with overlap_p2p_comm but not batch_p2p_comm." ) + if self.dynamic_context_parallel and not self.sequence_packing: + raise ValueError("Dynamic context parallel requires sequence packing to be enabled") + if self.sequence_packing: + if not HAVE_PACKAGING: + raise ImportError( + "packaging is not installed. Please install it with `pip install packaging`." + ) + # TODO: remove this after we fix the convergence issue with TE < 2.9. + if not ( + is_te_min_version("2.9.0") or get_te_version() == PkgVersion("2.9.0.dev0+5b3092a") + ): + raise ValueError( + "SFT sequence packing requires Transformer Engine >= 2.9.0 " + f"but got {get_te_version()} (TE < 2.9.0 may have convergence issues)." 
+ ) + if self.sequence_packing_scheduler == None: + if self.dynamic_context_parallel: + self.sequence_packing_scheduler = "default_dynamic_cp" + else: + self.sequence_packing_scheduler = "naive_sequence_packing" diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 7aa867fd98f..fa15d8e0946 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -115,8 +115,8 @@ _CONTEXT_PARALLEL_GLOBAL_RANKS = None # Hierarchical context parallel groups _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS = None -# Hybrid context parallel groups -_HYBRID_DP_CP_GROUPS = {} +# Dynamic context parallel groups +_DYNAMIC_DP_CP_GROUPS = {} # Data parallel group information with context parallel combined. _DATA_PARALLEL_GROUP_WITH_CP = None @@ -420,13 +420,13 @@ def create_hierarchical_groups( return hierarchical_groups, hierarchical_groups_gloo -def create_hybrid_dp_cp_groups(rank, ranks, pg_options): +def create_dynamic_dp_cp_groups(rank, ranks, pg_options): """ - Creates groups required for hybrid DPxCP. + Creates groups required for dynamic DPxCP. Creates a new group for every power of 2 up to the number of DPxCP ranks. Returns a dictionary indexed by group size. """ - hybrid_dp_cp_groups = {} + dynamic_dp_cp_groups = {} # Generate group for every power of 2 up to the number of CP ranks # We limit the allowed group sizes in order to avoid excessive overhead. 
group_sizes = [2**i for i in range(int(log2(len(ranks))))][1:] @@ -435,14 +435,14 @@ def create_hybrid_dp_cp_groups(rank, ranks, pg_options): group = create_group( ranks[i : i + group_size], pg_options=pg_options, - group_desc=f"HYBRID_DP_CP_GROUP_{group_size}", + group_desc=f"DYNAMIC_DP_CP_GROUP_{group_size}", ) if rank in ranks[i : i + group_size]: assert ( - group_size not in hybrid_dp_cp_groups - ), f"Rank {rank} appears in multiple Hybrid DP CP groups of size {group_size}" - hybrid_dp_cp_groups[group_size] = group - return hybrid_dp_cp_groups + group_size not in dynamic_dp_cp_groups + ), f"Rank {rank} appears in multiple Dynamic DP CP groups of size {group_size}" + dynamic_dp_cp_groups[group_size] = group + return dynamic_dp_cp_groups class RankGenerator(object): @@ -565,7 +565,8 @@ def initialize_model_parallel( create_gloo_process_groups: bool = True, high_priority_stream_groups: Optional[List[str]] = None, sharp_enabled_group: Optional[str] = None, - hybrid_context_parallel: bool = False, + dynamic_context_parallel: bool = False, + min_dynamic_context_parallel_size: int = 1, ) -> None: """Initialize model data parallel groups. @@ -917,18 +918,18 @@ def initialize_model_parallel( if "NCCL_COLLNET_ENABLE" in os.environ: del os.environ["NCCL_COLLNET_ENABLE"] - if hybrid_context_parallel: - global _HYBRID_DP_CP_GROUPS + if dynamic_context_parallel: + global _DYNAMIC_DP_CP_GROUPS for ranks_with_cp in decoder_rank_generator.get_ranks('dp-cp'): assert ( len(ranks_with_cp) % 2 == 0 - ), "Hybrid context parallel requires an even number of ranks" - _HYBRID_DP_CP_GROUPS.update( - create_hybrid_dp_cp_groups( + ), "Dynamic context parallel requires an even number of ranks" + _DYNAMIC_DP_CP_GROUPS.update( + create_dynamic_dp_cp_groups( rank, ranks_with_cp, get_nccl_options("dp_cp", nccl_comm_cfgs) ) ) - # TODO: Are gloo groups needed for hybrid cp? + # TODO: Are gloo groups needed for dynamic cp? 
for ranks in decoder_rank_generator.get_ranks('dp'): group = create_group( @@ -977,6 +978,22 @@ def initialize_model_parallel( if rank in ranks: _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS = hierarchical_groups + if dynamic_context_parallel: + # PyTorch is performing lazy initialization of the communicator group. + # Therefore, we need to perform a nccl call to ensure that the communicator group is created. + group_sizes = [ + 2**i + for i in range( + int(log2(min_dynamic_context_parallel_size)), int(log2(data_parallel_size)) + ) + ] + if group_sizes[-1] * 2 == data_parallel_size: + group_sizes.append(data_parallel_size) + for group_size in group_sizes: + group = get_dynamic_data_context_parallel_groups(group_size=group_size) + torch.distributed.barrier(group=group, device_ids=[torch.cuda.current_device()]) + torch.cuda.synchronize() + # Build the model-parallel groups. global _MODEL_PARALLEL_GROUP global _MODEL_PARALLEL_GLOBAL_RANKS @@ -1446,16 +1463,20 @@ def get_hierarchical_context_parallel_groups(check_initialized=True): return _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS -def get_hybrid_data_context_parallel_groups(check_initialized=True, group_size=None): - """Get the hybrid context parallel groups the caller rank belongs to.""" +def get_dynamic_data_context_parallel_groups(check_initialized=True, group_size=None): + """Get the dynamic context parallel groups the caller rank belongs to.""" # If the group size is the same as the entire DPxCP group, return the original group if get_data_parallel_world_size(with_context_parallel=True) == group_size: if check_initialized: assert _DATA_PARALLEL_GROUP_WITH_CP is not None return _DATA_PARALLEL_GROUP_WITH_CP + elif group_size == 1: + if check_initialized: + assert _CONTEXT_PARALLEL_GROUP is not None + return _CONTEXT_PARALLEL_GROUP if check_initialized: - assert _HYBRID_DP_CP_GROUPS is not None - return _HYBRID_DP_CP_GROUPS[group_size] + assert _DYNAMIC_DP_CP_GROUPS is not None + return _DYNAMIC_DP_CP_GROUPS[group_size] def 
get_embedding_group(check_initialized=True): diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py deleted file mode 100644 index 27b5fc87945..00000000000 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ /dev/null @@ -1,660 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -from collections import deque -from functools import lru_cache -from math import ceil, log2 -from typing import Callable, List, Optional, Tuple - -import torch - -from megatron.core import parallel_state -from megatron.core.rerun_state_machine import RerunDataIterator - - -class BalancedCPScheduler: - """ - This class provides the functionality to form groups of sub-samples - such that all DPxCP ranks have a roughly balanced workload in the group. - """ - - def __init__(self, max_seq_len_per_rank: int, dp_cp_group: torch.distributed.ProcessGroup): - self.max_seq_len_per_rank = max_seq_len_per_rank - self.num_subsamples = 0 - self.num_subsamples_processed = 0 - self.free_resources = [] - self.total_hdp_gpus = dp_cp_group.size() - - @lru_cache(maxsize=128) - def get_total_workload(self, seq_length: int, cp_size: Optional[int] = None): - """ - seq_length: sequence length of a sub-sample - cp_size: total number of CP ranks working on this sub-sample - - Note: - This function is used to estimate the relative workload intensity - of a sub-sample. This is not meant to be an accurate flops calculator. - - Returns: workload of a sub-sample - """ - if cp_size is None: - cp_size = self.gpus_needed(seq_length) - return (seq_length * seq_length) / cp_size - - @lru_cache(maxsize=128) - def gpus_needed(self, seq_len: int) -> int: - """ - Calculates the number of GPUs needed for a given sequence length - and max sequence length per CP rank. - This is used to determine the CP size of a sub-sample. 
- - The number is rounded up to the next power of 2 to match the available - hybrid context parallel process group sizes. - """ - return max(1, 2 ** ceil(log2((seq_len / self.max_seq_len_per_rank)))) - - def make_buckets_equal( - self, - sample_seqlens: List[Tuple[int, int]], # List of (sample_id, sequence_length) tuples - compute_estimator: Callable[[int], float], - ) -> List[deque]: - """ - Makes as many buckets as unique CP sizes needed. - This keeps sample IDs tethered to their sequence lengths throughout the bucketing process. - """ - # Extract just the sequence lengths for determining k - seqlens = [seq_len for _, seq_len in sample_seqlens] - - # Determine k based on unique GPU categories needed - k = len({self.gpus_needed(L) for L in seqlens}) - - # Create a work target for each bucket - # This is the total work divided by the number of buckets - work = [] - for _, s in sample_seqlens: - cp_size = self.gpus_needed(s) - work.append(compute_estimator(s, cp_size)) - total_work = sum(work) - target = total_work / k - buckets, cur, cur_work = [], [], 0.0 - remaining_work = total_work - remaining_k = k - - for i, (sample_id, seq_len) in enumerate(sample_seqlens): - work = compute_estimator(seq_len) - projected = cur_work + work - - # Check if we should close this bucket - if cur and ( - projected > target * 1.1 # Too much work - or len(sample_seqlens) - i <= remaining_k - len(buckets) - ): # Need to save sequences for remaining buckets - buckets.append(deque(cur)) - cur, cur_work = [], 0.0 - remaining_work -= sum(compute_estimator(seq_len) for _, seq_len in cur) - remaining_k -= 1 - - cur.append((sample_id, seq_len)) - cur_work += work - - if cur: - buckets.append(deque(cur)) - - return buckets - - def next_hdp_group( - self, - sample_seqlens: List[Tuple[int, int]], # List of (sample_id, sequence_length) tuples - compute_estimator: Callable[[int], float], - total_gpus: int, - delta: float = 0.05, # balance slack (e.g. 
5 %) - strategy: str = "dp", # "dp" or "pp" - eps_bucket: float = 0.10, # ε target for bucket balance - ) -> Tuple[List[List[int]], List[Tuple[int, int]], List[float], List[List[int]]]: - """ - Given a list of (sample_id, sequence_length) tuples, this function aims to assign - sequences in a group such that all GPUs in the DPxCP group have a roughly balanced - workload. Once each group is roughly balanced, we exit and return the - group and the leftover sequences. - - The function performs the following passes in order to form a balanced microbatch: - 1. We create buckets of sequences that are roughly balanced. - We try to create as many buckets as possible CP sizes. - 2. Given a bucket has sequences available, we assign the sample - a. To a new set of GPUs if there are enough free GPUs. - b. To an existing set of GPUs with the lowest load. - 3. We check if the group is balanced whenever we need to move onto a new CP size - in the same set of GPUs. - 4. We trim the group if removing the last added sequence helps improve balance. - 5. If we run out of sequences to assign and there are empty GPUs, - we redistribute work to empty GPUs by recursively increasing the CP size of a - sample until no empty GPUs are left. - - Returns (micro_batches, leftover_sample_seqlens, exec_times, sample_ids_per_gpu). 
- """ - if not sample_seqlens: - return ( - [[] for _ in range(total_gpus)], - [], - [0.0 for _ in range(total_gpus)], - [[] for _ in range(total_gpus)], - ) - - # Get buckets of sequences with balanced work - buckets = self.make_buckets_equal(sample_seqlens, compute_estimator) - - # Initialize tracking structures - micro_batches = [[] for _ in range(total_gpus)] - exec_times = [0.0 for _ in range(total_gpus)] - sample_ids_per_gpu = [[] for _ in range(total_gpus)] - - gpu_group_id = [None] * total_gpus - group_members = {} - group_size = {} - next_gid = 0 - - pp_cursor = 0 - prev_needed = None - check_balance = False - - while buckets: - # ---- Step 1 – pick the next sequence we COULD place ------------------ - sample_seq_tuple = bucket_idx = None - needed = None - - scan_order = ( - range(len(buckets)) - if strategy == "dp" - else [(pp_cursor + i) % len(buckets) for i in range(len(buckets))] - ) - - for idx in scan_order: - if not buckets[idx]: - continue - cand_tuple = buckets[idx][0] # This is now (sample_id, seq_len) - cand_seq_len = cand_tuple[1] - needed = self.gpus_needed(cand_seq_len) - - # (a) Do we have an *existing* group of size `needed`? - candidate_gids = [gid for gid, sz in group_size.items() if sz == needed] - - # (b) Or enough completely free GPUs to start a new group? - free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None] - if candidate_gids or len(free_ranks) >= needed: - sample_seq_tuple, bucket_idx = cand_tuple, idx - break - - # No place to put any remaining sequence – finish this micro‑batch - if sample_seq_tuple is None: - break - - # TODO[pmannan]: PP not yet supported. Add PP scheduling. 
- if strategy == "pp": - pp_cursor = (bucket_idx + 1) % len(buckets) - - sample_id, seq_len = sample_seq_tuple - needed = self.gpus_needed(seq_len) - if prev_needed is None: - prev_needed = needed - - # (a) Existing groups of exactly this size - candidate_gids = [gid for gid, sz in group_size.items() if sz == needed] - if candidate_gids: - best_gid, best_load = min( - ( - (gid, max(exec_times[r] for r in group_members[gid])) - for gid in candidate_gids - ), - key=lambda t: t[1], - ) - else: - best_gid, best_load = None, float("inf") - - # (b) Hypothetical **new** group from completely free GPUs - free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None] - if len(free_ranks) >= needed: - free_sorted = sorted(free_ranks, key=lambda r: exec_times[r]) - new_members = free_sorted[:needed] - new_load = exec_times[new_members[-1]] - - if new_load < best_load: - best_gid = None - chosen_members = new_members - else: - chosen_members = group_members[best_gid] - else: - chosen_members = group_members[best_gid] - - # ---- Step 2 – if we decided to create a fresh group ---------------- - if best_gid is None: - best_gid = next_gid - next_gid += 1 - group_members[best_gid] = chosen_members - group_size[best_gid] = needed - for r in chosen_members: - gpu_group_id[r] = best_gid - - # ---- Step 3 – assign the sequence to every member of that group ------ - per_gpu_cost = compute_estimator(seq_len) - - for r in chosen_members: - micro_batches[r].append(seq_len) - exec_times[r] += per_gpu_cost - sample_ids_per_gpu[r].append(sample_id) - - # Remove the sequence definitively from its bucket - buckets[bucket_idx].popleft() - - # ---- Step 4 – tidy, balance‑check, maybe early‑exit ------------------ - while buckets and not buckets[0]: - buckets.pop(0) - pp_cursor %= max(1, len(buckets)) - - # TODO: Removing this helps reduce the number of groups when we have - # lots of samples with same CP size. 
- # But because we don't exit as soon as we get balanced, - # even if there is one group available that can take the next sample, - # we will keep adding samples to the same group. - # trim_overload() does not help because it only checks if removing the - # last added sample helps. - # We cannot check after adding every sample because there will always be imbalance - # if we don't wait for future scheduling. - - # IMPORTANT: So we need a solution here - if needed < prev_needed: - # When we get into a lower CP size in the same group, - # we can start checking for balance. There is still a gotcha here. - # Let's say we have a group of 3 GPU 0-2, then we move onto group of 2. - # We keep assigning group of 2 as we do in descending order but GPU 7/15 - # never sees a microbatch assigned to it - # until we run out of samples with CP2. - # This means we are never balanced as min(exec_times) will always be 0. - # We need a smart way of identifying that we have run out of big samples - # and if we are having to assign work to a GPU already working, - # is it because there are empty GPUs? - # Would assigning work to empty GPUs first by moving onto next CP bucket help? - # But we need to remember to come back to this CP size bucket and then - # check for balance. Maybe the scheduling algorithm should look at empty - # GPUs and find work rather than going sequence by sequence. - check_balance = True - - if ( - check_balance - and buckets - and max(exec_times) - min(exec_times) <= delta * max(exec_times) - ): - break - - # Gather leftovers (flatten remaining buckets, preserve order) - leftovers = [] - for b in buckets: - for sample_seq_tuple in b: - leftovers.append(sample_seq_tuple) - - # --------------------------------------------------------------------------- - def trim_overload(): - """ - Iteratively pop the most‑recent sequence from the *most‑loaded group* - whenever doing so reduces the global slack. 
- """ - while True: - cur_max = max(exec_times) - cur_min = min(exec_times) - cur_slack = cur_max - cur_min - if cur_slack <= delta * cur_max: - # Slack is already within limit. - break - if cur_min == 0: - # There are empty GPUs that will be - # handled in the next step. - break - - max_r = exec_times.index(cur_max) - gid = gpu_group_id[max_r] - members = group_members[gid] - - if not micro_batches[max_r] or len(micro_batches[max_r]) <= 1: - break - - seq = micro_batches[max_r][-1] - need = group_size[gid] - per_gpu_cost = compute_estimator(seq) - - proj_times = exec_times[:] - for r in members: - proj_times[r] -= per_gpu_cost - - proj_slack = max(proj_times) - min(proj_times) - - # Check if trimming the workload helps imbalance - if proj_slack < cur_slack: - sample_id_to_remove = sample_ids_per_gpu[max_r][-1] - for r in members: - micro_batches[r].pop() - exec_times[r] -= per_gpu_cost - sample_ids_per_gpu[r].pop() - leftovers.append((sample_id_to_remove, seq)) - else: - break - - trim_overload() - - # Track samples in this group before redistribution to empty GPUs - total_work_before = sum(len(mb) for mb in micro_batches) - - # Check for empty GPUs and redistribute work - def fill_empty_gpus( - micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size - ): - """ - Recursively check for empty GPUs and redistribute work by increasing - the number of GPUs sharing samples. This ensures all GPUs have work. - GPUs must be allocated consecutively so we may need to push existing - work to other ranks in order to expand samples. 
- """ - # Find empty GPUs - empty_gpus = [i for i in range(total_gpus) if not micro_batches[i]] - if not empty_gpus: - return ( - micro_batches, - exec_times, - sample_ids_per_gpu, - group_members, - group_size, - ) # No empty GPUs, we're done - - # Find the smallest group size that exists - existing_group_sizes = set(group_size.values()) - assert ( - existing_group_sizes - ), "There should be at least one group existing, cannot reditribute, " - "try to increase 'max-seqlen-per-cp-rank'." - - min_group_size = min(existing_group_sizes) - # We have Hybrid DPxCP groups for every power of 2 of GPUs or the entire DPxCP group. - next_power = min(min_group_size * 2, total_gpus) - - # Find the first group of min_group_size that can be expanded - expandable_gid = None - expandable_members = None - expandable_new_gpus = None - - for gid, size in group_size.items(): - if size == min_group_size: - members = group_members[gid] - needed_count = next_power - min_group_size - group_start_gpu = members[0] - group_end_gpu = members[-1] - empty_gpu = [idx for idx, work in enumerate(micro_batches) if not work][0] - assert not all( - work for work in micro_batches[empty_gpu : empty_gpu + needed_count] - ), f"Empty GPUs were detected but not enough to expand." 
- work_to_push = micro_batches[ - group_end_gpu + 1 : empty_gpu - ] # This is work of all other subsequent sub-samples - exec_times_to_push = exec_times[group_end_gpu + 1 : empty_gpu] - sample_ids_to_push = sample_ids_per_gpu[group_end_gpu + 1 : empty_gpu] - - new_micro_batches = [[]] * len(micro_batches) - new_exec_times = [0.0] * len(exec_times) - new_sample_ids_per_gpu = [[]] * len(sample_ids_per_gpu) - - # No change in work until the group selected for expansion - for i in range(group_start_gpu): - new_micro_batches[i] = micro_batches[i] - new_exec_times[i] = exec_times[i] - new_sample_ids_per_gpu[i] = sample_ids_per_gpu[i] - - # The work is distributed across the expanded group - for i in range(group_start_gpu, group_end_gpu + needed_count + 1): - new_micro_batches[i] = micro_batches[group_end_gpu] - new_exec_times[i] = self.get_total_workload( - micro_batches[group_end_gpu][0], next_power - ) - new_sample_ids_per_gpu[i] = sample_ids_per_gpu[group_end_gpu] - - # Any assigned work on expanded GPUs is pushed - for i, work in enumerate(work_to_push): - new_micro_batches[group_end_gpu + needed_count + 1 + i] = work - new_exec_times[group_end_gpu + needed_count + 1 + i] = exec_times_to_push[i] - new_sample_ids_per_gpu[group_end_gpu + needed_count + 1 + i] = ( - sample_ids_to_push[i] - ) - - group_size[gid] = next_power - group_members[gid] = list(range(members[0], members[-1] + needed_count + 1)) - for pushed_gid in group_size.keys(): - if pushed_gid > gid: - group_members[pushed_gid] = [ - x + needed_count for x in group_members[pushed_gid] - ] - - return ( - new_micro_batches, - new_exec_times, - new_sample_ids_per_gpu, - group_members, - group_size, - ) - - empty_gpus = any([not micro_batches[i] for i in range(total_gpus)]) - while empty_gpus: - micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size = ( - fill_empty_gpus( - micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size - ) - ) - empty_gpus = any([not micro_batches[i] 
for i in range(total_gpus)]) - - # Assert that no sample has been completely removed - total_work_after = sum(len(mb) for mb in micro_batches) - assert ( - total_work_after >= total_work_before - ), f"Samples were removed: {total_work_before} -> {total_work_after}" - - return micro_batches, leftovers, exec_times, sample_ids_per_gpu - - def get_groups_and_subsamples(self, sample_id_seqlens, config): - """ - This function recursively forms groups of sub-samples such that all DPxCP ranks - have a roughly balanced workload in the group. - """ - groups = [] - sample_id_groups = [] - # We assign a sample_id to each sub-sample in order to track assignment to each GPU. - sample_id_seqlens = sorted(sample_id_seqlens, key=lambda x: x[1], reverse=True) - while sample_id_seqlens: - mb, sample_id_seqlens, exec_times, sample_ids = self.next_hdp_group( - sample_id_seqlens, self.get_total_workload, self.total_hdp_gpus - ) - groups.append(mb) - if len(sample_ids) < self.total_hdp_gpus: - sample_ids.extend([] * (self.total_hdp_gpus - len(sample_ids))) - sample_id_groups.append(sample_ids) - - return groups, sample_id_groups - - -def hybrid_context_parallel_forward_backward( - forward_step_func, - data_iterator, - model, - num_microbatches, - input_tensor, - output_tensor_grad, - forward_data_store, - config, - collect_non_loss_data, - first_val_step, - forward_only, - no_sync_func, - total_num_tokens, - check_first_val_step, - model_type, -): - """ - Scheduler for Hybrid Context Parallel. - - This function performs the packed sample scheduling and determines - 1. The number of microbatches to schedule for each CP rank - 2. The number of groups each CP rank should execute - 3. The number of sub-samples per group each CP rank should execute - - A group is defined by a set of samples that can run across the CP domain without any barrier. - There are many reasons why we may not be able to run endless samples within a single group. 
- For example, if we have 8 GPUs, - if GPU 0-5 are assigned a long sample that requires CP6, - GPU 6-7 are assigned a short sample that requires CP2, - The next sample which requires CP4 can be assigned GPU 4-7. - But GPU 6-7 will finish first and get deadlocked if GPU 4-5 are not participating in the group. - """ - from .schedules import backward_step, forward_step - - def _broadcast(item): - if item is not None: - torch.distributed.broadcast( - item, - parallel_state.get_tensor_model_parallel_src_rank(), - group=parallel_state.get_tensor_model_parallel_group(), - ) - - def _broadcast_num_samples_this_group(num_samples_this_group): - dev = torch.cuda.current_device() - torch.distributed.barrier() - - n = 0 if num_samples_this_group is None else int(num_samples_this_group.numel()) - n = torch.tensor([n], dtype=torch.int64, device=dev) - - _broadcast(n) - n = int(n.item()) - - assert n > 0, "there should be at least 1 sub samples in the group" - num_samples_this_group_broadcast = ( - torch.empty(n, dtype=torch.int32, device=dev) - if num_samples_this_group is None - else num_samples_this_group - ) - _broadcast(num_samples_this_group_broadcast) - return num_samples_this_group_broadcast - - def _get_new_data_iterator(sample_id_in_group, group_id): - if is_first_tp_rank: - sub_sample_id = sample_ids_this_group[sample_id_in_group] - sample = batch[sub_sample_id] - partner_cp_size = len( - [True for sample_ids in sample_id_groups[group_id] if sub_sample_id in sample_ids] - ) - sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) - new_data_iterator = RerunDataIterator(iter([sample])) - return new_data_iterator - else: - return None - - # We get data once per global batch and schedule the sub-samples. - # TODO(pmannan): Should we wrap the data_iterator here instead of the training.py file? 
- hdp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) - is_first_tp_rank = parallel_state.get_tensor_model_parallel_rank() == 0 - - if is_first_tp_rank: - data = next(data_iterator) - sample_id_groups = data[1] - batch = data[0] - else: - data, sample_id_groups, batch = None, None, None - - num_samples_this_group = None - if is_first_tp_rank: - num_samples_this_group = torch.tensor( - [len(group[hdp_rank]) for group in sample_id_groups], dtype=torch.int32, device='cuda' - ) - - num_samples_this_group = _broadcast_num_samples_this_group(num_samples_this_group) - num_samples_this_group = num_samples_this_group.cpu().numpy() - num_total_groups = num_samples_this_group.shape[0] - - current_microbatch = 0 - - # Upto last group, we don't need any sync. - with no_sync_func(): - for j in range(num_total_groups - 1): - sample_ids_this_group = sample_id_groups[j][hdp_rank] if is_first_tp_rank else None - for i in range(num_samples_this_group[j]): - # Call forward step for each sub-sample - new_data_iterator = _get_new_data_iterator(i, j) - # TODO: Find the usage of current_microbatch and is_first_microbatch and - # how that may affect my usage. - output_tensor, num_tokens = forward_step( - forward_step_func, - new_data_iterator, - model, - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data, - is_first_microbatch=check_first_val_step( - first_val_step, forward_only, current_microbatch == 0 - ), - current_microbatch=current_microbatch, - ) - current_microbatch += 1 - total_num_tokens += num_tokens.item() - if not forward_only: - backward_step( - input_tensor, output_tensor, output_tensor_grad, model_type, config - ) - - # Create a barrier at end of each group. - # This barrier ensures that all ranks are prepared to change assigned CP group sizes and - # no rank is starting a sub-sample ahead of it's partner ranks. 
- torch.distributed.barrier( - parallel_state.get_data_parallel_group(with_context_parallel=True) - ) - - # For the last group, we need to run the last sub-sample out of the context handler. - with no_sync_func(): - sample_ids_this_group = sample_id_groups[-1][hdp_rank] if is_first_tp_rank else None - for i in range(num_samples_this_group[-1] - 1): - new_data_iterator = _get_new_data_iterator(i, -1) - # Call forward step for each sub-sample - output_tensor, num_tokens = forward_step( - forward_step_func, - new_data_iterator, - model, - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data, - is_first_microbatch=check_first_val_step( - first_val_step, forward_only, current_microbatch == 0 - ), - current_microbatch=current_microbatch, - ) - current_microbatch += 1 - total_num_tokens += num_tokens.item() - if not forward_only: - backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - - # The last sub-sample of the last group of the last microbatch is - # run out of the context handler. 
- new_data_iterator = _get_new_data_iterator(-1, -1) - # Call forward step for each sub-sample - output_tensor, num_tokens = forward_step( - forward_step_func, - new_data_iterator, - model, - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data, - is_first_microbatch=check_first_val_step( - first_val_step, forward_only, current_microbatch == 0 - ), - current_microbatch=current_microbatch, - ) - total_num_tokens += num_tokens.item() - if not forward_only: - backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - - return forward_data_store, total_num_tokens diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 9dc79ed11f7..018b532cfa2 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -9,6 +9,7 @@ from torch.autograd.variable import Variable from megatron.core import parallel_state +from megatron.core.datasets.data_schedule import PackingScheduler, wrap_dataloader from megatron.core.enums import ModelType from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_reset, @@ -37,7 +38,6 @@ combined_1f1b_schedule_for_interleaved_pipelining, combined_1f1b_schedule_for_no_pipelining, ) -from .hybrid_cp_schedule import hybrid_context_parallel_forward_backward # Types Shape = Union[List[int], torch.Size] @@ -512,6 +512,31 @@ def check_first_val_step(first_val_step, forward_only, cond): return cond +def wrap_iterator_helper( + config, + data_iterator: Union[Iterator, List[Iterator]], + num_microbatches: int, + pg_collection: Optional[ProcessGroupCollection] = None, +): + """Warp data iterator for sequence packing if needed.""" + if config.sequence_packing: + scheduler_type_map = { + 'default_dynamic_cp': PackingScheduler.DEFAULT_DYNAMIC_CP, + 'empty_scheduler_with_packing': PackingScheduler.EMPTY_PACKING, + 'empty_scheduler_no_packing': 
PackingScheduler.EMPTY_NO_PACKING, + 'naive_sequence_packing': PackingScheduler.NAIVE_SEQUENCE_PACKING, + } + if config.sequence_packing_scheduler not in scheduler_type_map: + raise ValueError( + f"Invalid sequence packing scheduler: \ + {config.sequence_packing_scheduler}" + ) + scheduler_type = scheduler_type_map[config.sequence_packing_scheduler] + return wrap_dataloader(data_iterator, config, scheduler_type, pg_collection=None) + else: + return data_iterator, num_microbatches, None, None + + def forward_backward_no_pipelining( *, forward_step_func, @@ -594,6 +619,13 @@ def forward_backward_no_pipelining( input_tensor, output_tensor_grad = None, None total_num_tokens = torch.zeros([], dtype=torch.int, device="cuda") + ( + data_iterator, + num_microbatches, + num_total_tokens_this_global_batch, + sequence_square_sum_this_global_batch, + ) = wrap_iterator_helper(config, data_iterator, num_microbatches, pg_collection) + if config.overlap_moe_expert_parallel_comm and not forward_only: forward_data_store, total_num_tokens = combined_1f1b_schedule_for_no_pipelining( forward_step_func, @@ -611,24 +643,6 @@ def forward_backward_no_pipelining( total_num_tokens, partial(check_first_val_step, first_val_step, forward_only), ) - elif config.hybrid_context_parallel: - forward_data_store, total_num_tokens = hybrid_context_parallel_forward_backward( - forward_step_func, - data_iterator, - model, - num_microbatches, - input_tensor, - output_tensor_grad, - forward_data_store, - config, - collect_non_loss_data, - first_val_step, - forward_only, - no_sync_func, - total_num_tokens, - check_first_val_step, - model_type, - ) else: with no_sync_func(): for i in range(num_microbatches - 1): @@ -692,6 +706,11 @@ def forward_backward_no_pipelining( ): create_cudagraphs() + if config.sequence_packing and not forward_only: + forward_data_store.append( + [num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch] + ) + return forward_data_store @@ -1048,6 +1067,13 @@ def 
forward_backward_pipelining_with_interleaving( if config.overlap_p2p_comm and config.batch_p2p_comm: raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") + ( + data_iterator, + num_microbatches, + num_total_tokens_this_global_batch, + sequence_square_sum_this_global_batch, + ) = wrap_iterator_helper(config, data_iterator, num_microbatches, pg_collection) + # Needed only when gradients are finalized in M-Core if config.finalize_model_grads_func is not None and not forward_only: # vp is ignored for clear_embedding_activation_buffer @@ -2064,6 +2090,11 @@ def pp_post_backward(input_tensor_grad, vp_stage=None): create_cudagraphs() nvtx_range_pop(suffix="misc") + if config.sequence_packing and not forward_only: + forward_data_store.append( + [num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch] + ) + return forward_data_store @@ -2181,6 +2212,13 @@ def forward_backward_pipelining_without_interleaving( "provide none or provide all the process groups" ) + ( + data_iterator, + num_microbatches, + num_total_tokens_this_global_batch, + sequence_square_sum_this_global_batch, + ) = wrap_iterator_helper(config, data_iterator, num_microbatches, pg_collection) + # Needed only when gradients are finalized in M-Core if config.finalize_model_grads_func is not None and not forward_only: embedding_module = clear_embedding_activation_buffer( @@ -2450,4 +2488,9 @@ def enable_grad_sync(): ): create_cudagraphs() + if config.sequence_packing and not forward_only: + forward_data_store.append( + [num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch] + ) + return forward_data_store diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 0c5309a5876..d292be7f4a6 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -762,6 +762,11 @@ def forward( (Tuple[Tensor, Tensor]) Attention output and bias. 
""" + # here we need to set the right cp group for dynamic-cp + if packed_seq_params is not None and packed_seq_params.local_cp_size is not None: + assert packed_seq_params.cp_group is not None, "cp_group must be set in dynamic-cp mode" + self.pg_collection.cp = packed_seq_params.cp_group + # Check if we need to skip RoPE # no_rope is 0-indexed array and self.layer_number is 1-indexed no_rope = ( diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index ed90fdffa97..2cf94320baa 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -558,8 +558,8 @@ def get_query_key_value_tensors( if packed_seq_params is not None: assert ( packed_seq_params.local_cp_size is None - ), "hybrid_context_parallel is not supported with MLA yet and is planned for future. \ - Please disable hybrid_context_parallel." + ), "dynamic_context_parallel is not supported with MLA yet and is planned for future. \ + Please disable dynamic_context_parallel." 
inference_context = deprecate_inference_params(inference_context, inference_params) diff --git a/megatron/core/utils.py b/megatron/core/utils.py index 62ce07586be..cb51d791fea 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -73,15 +73,6 @@ logger = logging.getLogger(__name__) -try: - # Register the TE CUDA kernels - import transformer_engine # pylint: disable=unused-import - - # Alias the PyTorch wrapper so we can call tex.* APIs - import transformer_engine_torch as tex -except ImportError: - # TE isn’t installed or the torch wrapper is missing - tex = None try: _torch_version = PkgVersion(torch.__version__) @@ -2094,103 +2085,6 @@ def get_batch_on_this_cp_rank( return batch -def get_thd_batch_on_this_cp_rank( - batch: Dict[str, Any], - cu_seqlens: torch.Tensor, - cu_seqlens_padded: torch.Tensor, - max_seqlen: torch.Tensor, - cp_group: Optional[torch.distributed.ProcessGroup] = None, -): - """Slice each sub-sample in a packed sample batch input along - sequence dimension into multiple chunks, which are parallelized - across GPUs in a context parallel group. 
- """ - packed_seq_params = PackedSeqParams( - qkv_format="thd", - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - cu_seqlens_q_padded=cu_seqlens_padded, - cu_seqlens_kv_padded=cu_seqlens_padded, - max_seqlen_q=int(max_seqlen[0].item()), - max_seqlen_kv=int(max_seqlen[0].item()), - ) - - if cp_group is not None: - cp_size = get_pg_size(cp_group) - cp_rank = get_pg_rank(cp_group) - else: - cp_size = parallel_state.get_context_parallel_world_size() - cp_rank = parallel_state.get_context_parallel_rank() - if cp_size > 1: # slice batch along sequence dimension for context parallelism - assert tex is not None and is_te_min_version("1.10.0"), ( - "Please update Transformer Engine to >= 1.10 to use " - "Context Parallel with THD format data" - ) - index = tex.thd_get_partitioned_indices( - cu_seqlens_padded, batch['tokens'].size(1), cp_size, cp_rank - ) - for key, data in batch.items(): - if key in {'attention_mask', 'cu_seqlens', 'cu_seqlens_padded', 'max_seqlen'}: - continue - batch[key] = data.index_select(1, index) - - return batch, packed_seq_params - - -################################ -### hybrid context parallel ### -################################ - - -def get_batch_on_this_hybrid_cp_rank( - batch: Dict[str, Any], - local_cp_size: int, - cp_group: Optional[torch.distributed.ProcessGroup] = None, -): - """Slice batch input along sequence dimension into multiple chunks, - which are parallelized across GPUs in a context parallel group. 
- """ - assert local_cp_size is not None - if cp_group is None: - # Get the local cp group required for as defined by the HybridCPDataLoaderWrapper - if local_cp_size > 1: - cp_group = parallel_state.get_hybrid_data_context_parallel_groups( - group_size=local_cp_size - ) - else: - # If cp group is provided, it must match the local cp size - # as defined by the HybridCPDataLoaderWrapper - assert cp_group.size() == local_cp_size - - # Convert [seqlen] to [1, seqlen] similar to default collate_fn - # as hybrid_context_parallel dataloader wrapper does not go through default collate_fn - for key, data in batch.items(): - if key in ['attention_mask']: - continue - batch[key] = torch.stack([data], 0) - sample_length = batch['tokens'].shape[1] - # TODO(pmannan): Take care of padding tokens here if not divisible by cp_size*2 - # Create packed_seq_params for SBHD format with cp group information. - packed_seq_params = PackedSeqParams( - qkv_format="sbhd", - cu_seqlens_q=torch.tensor([0, sample_length], device="cuda", pin_memory=True), - cu_seqlens_kv=torch.tensor([0, sample_length], device="cuda", pin_memory=True), - cu_seqlens_q_padded=torch.tensor([0, sample_length], device="cuda", pin_memory=True), - cu_seqlens_kv_padded=torch.tensor([0, sample_length], device="cuda", pin_memory=True), - max_seqlen_q=sample_length, - max_seqlen_kv=sample_length, - local_cp_size=local_cp_size, - cp_group=cp_group, - ) - - if cp_group is not None and cp_group.size() > 1: - # When using hybrid_context_parallel, each sub-sample of a packed sample is - # required to be divisible by CP*DP*2 or CP*DP*TP*2 (if using sequence parallel) - batch = get_batch_on_this_cp_rank(batch, cp_group) - - return batch, packed_seq_params - - ###################### ### NVTX profiling ### ###################### diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 9aba3a7cb8e..eb74b53ad5d 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -810,7 +810,13 
@@ def validate_args(args, defaults={}): # across batches/microbatches. Due to additional communication overhead # during pipeline parallelism, it should not be set if sequence length # is constant during training. - args.variable_seq_lengths = False + if args.sequence_packing: + args.variable_seq_lengths = True + # TODO(tailaim): add support for other dispatcher types + print(f"Setting moe_token_dispatcher_type to alltoall for sft sequence packing with pipeline parallelism") + args.moe_token_dispatcher_type = "alltoall" + else: + args.variable_seq_lengths = False # Iteration-based training. if args.train_iters: @@ -963,12 +969,31 @@ def validate_args(args, defaults={}): if args.tp_comm_overlap: assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled' - if args.hybrid_context_parallel: - assert not args.pipeline_model_parallel_size > 1, 'Hybrid context parallelism not supported with pipeline parallelism' - assert not args.enable_cuda_graph, 'Hybrid context parallelism not supported with CUDA Graph' - assert not args.use_megatron_fsdp, 'Hybrid context parallelism not supported with Megatron FSDP' - assert args.dataloader_type == 'single', 'Hybrid context parallelism only supported with single dataloader type' - assert args.calculate_per_token_loss, 'Hybrid context parallelism must be used with --calculate-per-token-loss' + if args.dynamic_context_parallel: + assert not (args.pipeline_model_parallel_size > 1 and args.use_megatron_fsdp), \ + 'Dynamic context parallelism not supported with pipeline parallelism when using FSDP' + assert not args.enable_cuda_graph, 'Dynamic context parallelism not supported with CUDA Graph' + assert args.dataloader_type == 'single', 'Dynamic context parallelism only supported with single dataloader type' + assert args.calculate_per_token_loss, 'Dynamic context parallelism must be used with --calculate-per-token-loss' + assert args.context_parallel_size == 1, 
'context parallel size must be 1 for dynamic context parallelism' + + if args.sequence_packing: + assert not args.create_attention_mask_in_dataloader, \ + 'Sequence packing does not support create_attention_mask_in_dataloader. ' \ + 'Please set --no-create-attention-mask-in-dataloader' + # Validate that packed sequence buffer is large enough for single sequences + if args.dynamic_context_parallel: + # packed_buffer_size = hdp_size * max_seqlen_per_rank >= single_seq_max_len + hdp_size = args.world_size // (args.tensor_model_parallel_size * args.pipeline_model_parallel_size) + assert hdp_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \ + f'Packed sequence buffer size ({hdp_size * args.max_seqlen_per_dp_cp_rank}) ' \ + f'must be >= single sequence max length ({args.seq_length})' + else: + # packed_buffer_size = cp_size * max_seqlen_per_rank >= single_seq_max_len + assert args.context_parallel_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \ + f'Packed sequence buffer size ({args.context_parallel_size * args.max_seqlen_per_dp_cp_rank}) ' \ + f'must be >= single sequence max length ({args.seq_length})' + # disable async_tensor_model_parallel_allreduce when # model parallel memory optimization is enabled @@ -2896,13 +2921,25 @@ def _add_distributed_args(parser): '--hierarchical-context-parallel-sizes 2 4 indicates every two adjacent gpus ' 'forms the first level of cp groups and the cp ranks with the same odevity ' 'forms the second level of cp groups.') - group.add_argument('--max-seqlen-per-cp-rank', type=int, default=None, + group.add_argument('--max-seqlen-per-dp-cp-rank', type=int, default=None, help='Maximum sequence length per CP rank. This is used to calculate the ' 'number of sub-samples assigned to each CP rank when using heterogeneous context parallel.') - group.add_argument('--hybrid-context-parallel', action='store_true', default=False, - help='Enables hybrid context parallel. 
This is used to balance the workload ' + group.add_argument('--dynamic-context-parallel', action='store_true', default=False, + help='Enables dynamic context parallel. This is used to balance the workload ' 'of each CP rank when we use packed samples with variable sequence lengths. ' - 'Requires --max-seqlen-per-cp-rank to be set.') + 'Requires --max-seqlen-per-dp-cp-rank to be set.') + group.add_argument('--min-dynamic-context-parallel-size', type=int, default=1, + help='Minimum size of the dynamic context parallel groups.') + group.add_argument('--sequence-packing-scheduler', type=str, default=None, + choices=['default_dynamic_cp', 'empty_scheduler_with_packing', 'empty_scheduler_no_packing', 'naive_sequence_packing'], + help='Scheduler for sequence packing and dynamic context parallel. ' + 'naive_sequence_packing: default naive sequence packing scheduler(just THD, no Dynamic-CP, this ' + 'is just for comparison with default Dynamic-CP scheduler, not recommended for production) ' + 'default_dynamic_cp: default dynamic-cp scheduler for dynamic context parallel provided by MCore. ' + 'empty_scheduler_with_packing: scheduling is already handled by the data sampler, ' + 'this scheduler only performs packing. ' + 'empty_scheduler_no_packing: scheduling and packing are already handled by the data sampler, ' + 'this scheduler only returns the batch.') group.add_argument('--nccl-communicator-config-path', type=str, default=None, help='Path to the yaml file with NCCL communicator ' 'configurations. 
The number of min/max thread groups and thread ' @@ -3629,4 +3666,8 @@ def _add_sft_args(parser): group.add_argument('--sft', action="store_true", help='Megatron SFT training') group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned", help='SFT prompt format.') + group.add_argument('--sequence-packing', action='store_true', + help='use sequence packing(thd format) for training') + group.add_argument('--sft-mock-dataset-config-json', type=str, default=None, + help='This config provides the necessary information for the mock dataset. You can either specify a CSV file that contains sequence lengths, where each line stores the length of a sequence, for example: {"mode":"file","path":"/path/to/file"}. Alternatively, you can specify a distribution (currently only supporting lognormal distribution) along with the required parameters, for example, {"mode":"distribution","type":"lognormal","min_seq_len":1024,"max_seq_len":2048,"mean_seq_len":1536,"lognormal_sigma":1.1}, where sigma controls the variability of the lognormal distribution.') return parser diff --git a/megatron/training/datasets/data_samplers.py b/megatron/training/datasets/data_samplers.py index d33250520dd..6433e751d8e 100644 --- a/megatron/training/datasets/data_samplers.py +++ b/megatron/training/datasets/data_samplers.py @@ -39,8 +39,8 @@ def build_pretraining_data_loader(dataset, consumed_samples): data_parallel_size=mpu.get_data_parallel_world_size(), ) elif args.dataloader_type == 'single': - if args.hybrid_context_parallel: - batch_sampler = HybridCPMegatronPretrainingSampler( + if args.sequence_packing: + batch_sampler = MegatronSequencePackingSampler( total_samples=len(dataset), consumed_samples=consumed_samples, micro_batch_size=args.micro_batch_size, @@ -79,7 +79,7 @@ def worker_init_fn(_): worker_init_fn if args.exit_signal_handler and args.num_workers > 0 else None ) # Torch dataloader. 
- if args.hybrid_context_parallel: + if args.sequence_packing: extra_kwargs = {"collate_fn": lambda x: x,} else: extra_kwargs = {} @@ -161,11 +161,11 @@ def __iter__(self): start_idx, end_idx = self.get_start_end_idx() yield batch[start_idx:end_idx] -class HybridCPMegatronPretrainingSampler(MegatronPretrainingSampler): +class MegatronSequencePackingSampler(MegatronPretrainingSampler): """ - Data sampler for hybrid context parallel (Hybrid CP) format. + Data sampler for sequence packing. This data sampler pulls in the entire global batch at once across all data parallel ranks. - This helps provide the Hybrid CP Dataloader Wrapper to schedule and load balance sub-samples + This helps provide the sequence packing Dataloader Wrapper to schedule and load balance sub-samples of the entire global batch. """ diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index e4d8a6faf24..aa74c797673 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -1,8 +1,11 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
-from typing import Any, Dict, Optional +import json +import math +from typing import Any, Dict, Optional, List import numpy as np +import pandas as pd import torch from megatron.core.datasets.gpt_dataset import GPTDatasetConfig @@ -56,6 +59,8 @@ def __init__( config: GPTDatasetConfig, ) -> None: super().__init__(dataset, dataset_path, indices, num_samples, index_split, config) + # Pre-calculate padding divisor to avoid redundant computation in get_padding_size + self.padding_divisor = self._calculate_padding_divisor() @staticmethod def numel_low_level_dataset(low_level_dataset: LowLevelDataset) -> int: @@ -68,8 +73,38 @@ def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> LowL def __len__(self) -> int: return self.num_samples - def __getitem__(self, idx: int) -> Dict[str, Any]: + def _calculate_padding_divisor(self) -> int: + """ + Calculate the divisor used for sequence padding. + tp_pad = tp_size * 2 if tp_size > 1 else 1 + cp_pad = cp_size * 2 if cp_size > 1 else 1 + cp_pad = cp_pad * dp_size if dynamic_cp else cp_pad + divisor = cp_pad * tp_pad + """ + if self.config.dynamic_context_parallel: + # Dynamic CP: consider both CP and DP + cp_pad = self.config.data_parallel_size * self.config.context_parallel_size * 2 + else: + # Standard CP: only consider CP + cp_pad = self.config.context_parallel_size * 2 if self.config.context_parallel_size > 1 else 1 + tp_pad = self.config.sequence_parallel_size if self.config.sequence_parallel_size > 0 else 1 + divisor = cp_pad * tp_pad + # TODO(tailaim): do we need to pad for FP8 execution? 
+ # divisor = ((divisor + 15) // 16) * 16 + return divisor + + def get_padding_size( + self, + seq_len: int, + ) -> int: + seq_len_padded = math.ceil(seq_len / self.padding_divisor) * self.padding_divisor + assert seq_len > seq_len_padded / 2 / self.config.context_parallel_size * (self.config.context_parallel_size - 1), \ + f"sequence length {seq_len} is too short, the divisor is {self.padding_divisor}, that means cp_rank \ + {self.config.context_parallel_size-1} will have no valid tokens" + return seq_len_padded + def __getitem__(self, idx: int) -> Dict[str, Any]: + sequence_packing = self.config.sequence_packing tokenizer = self.config.tokenizer max_seq_len = self.config.sequence_length @@ -84,9 +119,13 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: tokens = tokens[: max_seq_len - force_eod_length] target = target[: max_seq_len - force_eod_length] - # padding + # if use sequence packing, pad according to get_padding_size + # else pad to max_seq_len num_tokens = len(tokens) + force_eod_length - padding_len = max_seq_len - num_tokens + if sequence_packing: + padding_len = self.get_padding_size(num_tokens) - num_tokens + else: + padding_len = max_seq_len - num_tokens assert padding_len >= 0 filler = [tokenizer.eod] * force_eod_length + [tokenizer.pad] * (padding_len + 1) @@ -98,9 +137,10 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: tokens = tokens[:-1].contiguous() target = target[1:].contiguous() + seq_len = tokens.numel() loss_mask, position_ids, attention_mask = self._get_ltor_masks_and_position_ids( - max_seq_len, target, tokenizer.pad + seq_len, target, tokenizer.pad ) if self.config.create_attention_mask: @@ -119,6 +159,11 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: 'position_ids': position_ids, } + if sequence_packing: + # sequence packing need both original sequence length and padded length + ret['original_seq_len'] = torch.tensor(num_tokens, dtype=torch.int32, device=tokens.device) + ret['padded_seq_len'] = torch.tensor(seq_len, 
dtype=torch.int32, device=tokens.device) + return ret def _get_ltor_masks_and_position_ids(self, max_seq_len, target, pad_token): @@ -136,7 +181,7 @@ def _get_ltor_masks_and_position_ids(self, max_seq_len, target, pad_token): if self.config.create_attention_mask: attention_mask = torch.tril( - torch.ones((seq_length, seq_length), device=data.device) + torch.ones((max_seq_len, max_seq_len), device=target.device) ).unsqueeze(0) # Convert attention mask to binary: attention_mask = attention_mask < 0.5 @@ -144,3 +189,138 @@ def _get_ltor_masks_and_position_ids(self, max_seq_len, target, pad_token): attention_mask = None return loss_mask, position_ids, attention_mask + +class MockSFTLowLevelDataset: + """The low-level mock dataset for SFT + + Args: + mock_config (dict): The config for mock dataset. + """ + + seed: int = 0 + """The hard-coded random seed to use to set the NumPy RNG""" + + size: int = 1000000 + """The hard-coded number of sequence to generate""" + + # This is to maintain consistency with the SFT dataset that uses real data. In the real dataset, an element in the low-level dataset often contains multiple sequences. So here, each element in the mock low-level dataset also contains num_sequence_per_sample sequences. This will be made more reasonable in the future. 
+ + + def __init__(self, config: Dict) -> None: + np.random.seed(self.seed) + # either choose to load sequence lengths from external file, or generate random sequence lengths + + assert "mode" in config, f"mode must be set, either 'file' or 'distribution'" + + if config["mode"] == "file": + self.sequence_lengths = np.array(pd.read_csv(config["path"])).flatten() + self.size = len(self.sequence_lengths) + elif config["mode"] == "distribution": + min_seq_len = config["min_seq_len"] + max_seq_len = config["max_seq_len"] + mean_seq_len = config["mean_seq_len"] + if config["type"] == "lognormal": + lognormal_sigma = config["lognormal_sigma"] + self.sequence_lengths = self.generate_lognormal_samples(self.size, mean_seq_len,lognormal_sigma, min_seq_len, max_seq_len) + else: + raise ValueError(f"Unsupported sequence length distribution type {config['type']}") + + def generate_lognormal_samples(self, size, mean, sigma, min_seq_len, max_seq_len): + mu = np.log(mean) - sigma**2 / 2 + samples = np.random.lognormal(mu, sigma, size) + samples = np.clip(samples, min_seq_len, max_seq_len) + return samples.astype(int) + + def __len__(self) -> int: + return self.size + + def __getitem__(self, idx: int) -> List[np.ndarray]: + length = self.sequence_lengths[idx % self.size] + # the length of sample is 'length', but only length-1 elements are generated here, + # because an eod token will be appended at the end later in SFTDataset + sample = np.arange(1, length, dtype=np.int64) + return sample +class MockSFTDataset(SFTDataset): + """The mock dataset used during SFT""" + + def __init__( + self, + dataset: LowLevelDataset, + dataset_path: Optional[str], + indices: np.ndarray, + num_samples: Optional[int], + index_split: Split, + config: GPTDatasetConfig, + ) -> None: + super().__init__(dataset, dataset_path, indices, num_samples, index_split, config) + + @staticmethod + def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> LowLevelDataset: + mock_config = 
json.loads(config.sft_mock_dataset_config_json) + return MockSFTLowLevelDataset(mock_config) + + def __len__(self) -> int: + return self.num_samples + + def __getitem__(self, idx: int) -> Dict[str, Any]: + sequence_packing = self.config.sequence_packing + tokenizer = self.config.tokenizer + max_seq_len = self.config.sequence_length + + tokens = self.dataset[int(self.indices[idx % len(self.indices)])] + target = np.array(tokens, dtype=np.int64) + + force_eod_length = int(tokenizer.force_eod) + + if len(tokens) > max_seq_len - force_eod_length: + # cut the right side + tokens = tokens[: max_seq_len - force_eod_length] + target = target[: max_seq_len - force_eod_length] + # tokens = tokens[(-max_seq_len + force_eod_length):] + # target = target[(-max_seq_len + force_eod_length):] + + # padding + num_tokens = len(tokens) + force_eod_length + if sequence_packing: + padding_len = self.get_padding_size(num_tokens) - num_tokens + else: + padding_len = max_seq_len - num_tokens + assert padding_len >= 0 + filler = [tokenizer.eod] * force_eod_length + [tokenizer.pad] * (padding_len + 1) + + tokens = np.array(tokens.tolist() + filler, dtype=np.int64) + target = np.array(target.tolist() + filler, dtype=np.int64) + + tokens = torch.tensor(tokens) + target = torch.tensor(target) + + tokens = tokens[:-1].contiguous() + target = target[1:].contiguous() + seq_len = tokens.numel() + + loss_mask, position_ids, attention_mask = self._get_ltor_masks_and_position_ids( + seq_len, target, tokenizer.pad + ) + + if self.config.create_attention_mask: + ret = { + 'tokens': tokens, + 'labels': target, + 'attention_mask': attention_mask, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + } + else: + ret = { + 'tokens': tokens, + 'labels': target, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + } + + if sequence_packing: + # sequence packing need both original sequence length and padded length + ret['original_seq_len'] = torch.tensor(num_tokens, dtype=torch.int32, 
device=tokens.device) + ret['padded_seq_len'] = torch.tensor(seq_len, dtype=torch.int32, device=tokens.device) + + return ret diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py index 1a119b127e4..d7aeddb3d88 100644 --- a/megatron/training/initialize.py +++ b/megatron/training/initialize.py @@ -371,7 +371,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s use_sharp=args.use_sharp, context_parallel_size=args.context_parallel_size, hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes, - hybrid_context_parallel=args.hybrid_context_parallel, + dynamic_context_parallel=args.dynamic_context_parallel, expert_model_parallel_size=args.expert_model_parallel_size, num_distributed_optimizer_instances=args.num_distributed_optimizer_instances, expert_tensor_parallel_size=args.expert_tensor_parallel_size, @@ -383,6 +383,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s create_gloo_process_groups=args.enable_gloo_process_groups, high_priority_stream_groups=args.high_priority_stream_groups, sharp_enabled_group=args.sharp_enabled_group, + min_dynamic_context_parallel_size=args.min_dynamic_context_parallel_size, ) if args.rank == 0: print( diff --git a/megatron/training/tokenizer/sft_tokenizer.py b/megatron/training/tokenizer/sft_tokenizer.py index f525352e892..5801dec53f6 100644 --- a/megatron/training/tokenizer/sft_tokenizer.py +++ b/megatron/training/tokenizer/sft_tokenizer.py @@ -62,7 +62,9 @@ def __init__( raise NotImplementedError("unknown SFT prompt format", prompt_format) self._prompt_format = prompt_format - + if self._prompt_config.pad_token_id is None: + self._prompt_config.pad_token_id = self._tokenizer.eos_token_id - 1 + print(f"pad token id is not set, set to (eos_token_id - 1): {self._prompt_config.pad_token_id} for {prompt_format}") def tokenize_conversation( self, conversation: List[Dict], return_target: bool, add_generation_prompt: bool @@ -179,6 
+181,11 @@ def bos(self): def eod(self): """End of sentence token ID.""" return self._tokenizer.eos_token_id + + @property + def eos(self): + """End of sentence token ID.""" + return self._tokenizer.eos_token_id @property def vocab(self): diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py index 5371c0d318e..3f367a3f7a4 100644 --- a/megatron/training/tokenizer/tokenizer.py +++ b/megatron/training/tokenizer/tokenizer.py @@ -871,6 +871,14 @@ def eos(self): def additional_special_tokens_ids(self): return None + @property + def force_eod(self): + """To force an EOD at the end of every data sample in SFT.""" + return True + + @property + def pad(self): + return self._eod_id - 1 class _NullMultimodalTokenizer(MegatronLegacyTokenizer): def __init__(self, vocab_size, image_token=None, image_token_id=None): diff --git a/megatron/training/training.py b/megatron/training/training.py index 5b171821497..5373de7c808 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -105,7 +105,6 @@ from megatron.training.initialize import set_jit_fusion_options from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank from megatron.training.datasets.data_samplers import build_pretraining_data_loader -from megatron.core.datasets.data_schedule import HybridCPDataLoaderWrapper from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler from megatron.core.transformer.moe import upcycling_utils from megatron.core.transformer.moe.moe_utils import track_moe_metrics @@ -178,7 +177,7 @@ def print_datetime(string): print_rank_0(f'[{string}] datetime: {time_str} ') -def num_floating_point_operations(args, batch_size): +def num_floating_point_operations(args, num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch): def calculate_layer_counts(): """Calculate the number of attention, Mamba, and MLP layers.""" if args.hybrid_override_pattern: @@ -194,44 +193,42 @@ def 
calculate_layer_counts(): num_moe_layers = 0 return num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers - def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False): + def mlp_layer_flops(num_total_tokens_this_global_batch, hidden_size, expansion=4.0, swiglu=False): """Calculate FLOPs for an MLP layer.""" scale_factor = 3.0 / 2.0 if swiglu else 1.0 - return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2 + return 4 * expansion * scale_factor * num_total_tokens_this_global_batch * hidden_size**2 - def moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size, + def moe_layer_flops(num_total_tokens_this_global_batch, hidden_size, moe_ffn_hidden_size, shared_expert_ffn_hidden_size, num_experts_routed_to, moe_latent_size=None, swiglu=False): """Calculate FLOPs for an MoE layer.""" scale_factor = 3.0 / 2.0 if swiglu else 1.0 if moe_latent_size is None: - routed_flops = (4 * batch_size * seq_len * hidden_size * + routed_flops = (4 * num_total_tokens_this_global_batch * hidden_size * moe_ffn_hidden_size * num_experts_routed_to * scale_factor) else: # Routed experts run on moe_latent_size. - routed_flops = (4 * batch_size * seq_len * moe_latent_size * + routed_flops = (4 * num_total_tokens_this_global_batch * moe_latent_size * moe_ffn_hidden_size * num_experts_routed_to * scale_factor) # Up proj and down proj. 
- routed_flops += (4 * batch_size * seq_len * hidden_size * moe_latent_size) - shared_flops = 4 * batch_size * seq_len * hidden_size * shared_expert_ffn_hidden_size * scale_factor + routed_flops += (4 * num_total_tokens_this_global_batch * hidden_size * moe_latent_size) + shared_flops = 4 * num_total_tokens_this_global_batch * hidden_size * shared_expert_ffn_hidden_size * scale_factor return routed_flops + shared_flops def attn_layer_flops( - batch_size, seq_len, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None + num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None ): """Calculate FLOPs for an attention layer.""" p = (kv_channels * num_heads / hidden_size) if kv_channels else 1 g = gqa_groups if gqa else num_heads return ( 4 - * batch_size - * seq_len * hidden_size * p - * (hidden_size + (hidden_size * (g / num_heads)) + (seq_len / 2)) + * (hidden_size * num_total_tokens_this_global_batch + (hidden_size * (g / num_heads)) * num_total_tokens_this_global_batch + (sequence_square_sum_this_global_batch / 2)) ) - def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16, + def mamba_layer_flops(num_total_tokens_this_global_batch, hidden_size, state_dim=16, head_dim=64, num_groups=1, num_heads=128): """Calculate FLOPs for a Mamba layer.""" # Note (rwaleffe): flops estimate for scan should be updated based on new SSD kernels, @@ -244,16 +241,15 @@ def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16, return ( ( 2 - * batch_size - * seq_len + * num_total_tokens_this_global_batch * hidden_size * (2 * d_in + 2 * num_groups * state_dim + nheads) ) # in_proj - + (7 * batch_size * seq_len * d_in * state_dim) # scan - + (2 * batch_size * seq_len * d_in * hidden_size) # out_proj + + (7 * num_total_tokens_this_global_batch * d_in * state_dim) # scan + + (2 * num_total_tokens_this_global_batch * d_in * hidden_size) # out_proj ) - def 
hybrid_flops(batch_size, seq_len, hidden_size, + def hybrid_flops(num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch, hidden_size, num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers, mamba_state_dim=128, mamba_head_dim=64, mamba_num_groups=8, mamba_num_heads=128, @@ -265,17 +261,17 @@ def hybrid_flops(batch_size, seq_len, hidden_size, vocab_size=256000): """Calculate total FLOPs for the hybrid model.""" flops_fwd = ( - num_attn_layers * attn_layer_flops(batch_size, seq_len, hidden_size, + num_attn_layers * attn_layer_flops(num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch, hidden_size, num_attn_heads, gqa, gqa_groups, kv_channels) + - num_mlp_layers * mlp_layer_flops(batch_size, seq_len, hidden_size, + num_mlp_layers * mlp_layer_flops(num_total_tokens_this_global_batch, hidden_size, mlp_expansion, swiglu) + - num_mamba_layers * mamba_layer_flops(batch_size, seq_len, hidden_size, + num_mamba_layers * mamba_layer_flops(num_total_tokens_this_global_batch, hidden_size, mamba_state_dim, mamba_head_dim, mamba_num_groups, mamba_num_heads) + - num_moe_layers * moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size, + num_moe_layers * moe_layer_flops(num_total_tokens_this_global_batch, hidden_size, moe_ffn_hidden_size, shared_expert_ffn_hidden_size, num_experts_routed_to, moe_latent_size, swiglu) + - (2 * batch_size * seq_len * hidden_size * vocab_size) # logits computation + (2 * num_total_tokens_this_global_batch * hidden_size * vocab_size) # logits computation ) return flops_fwd * 3 @@ -348,13 +344,18 @@ def transformer_flops(): assert not args.group_query_attention ''' Basic arithmetic - let B is batch size, s is seq_len, h is embedding dim, - for one self_attnetion block (prenorm is not included) - qkv projection: 6Bsh^2 - attn: 2Bs^2h - attn over value: 2Bs^2h - oproj: 2Bsh^2 - + + Let h be the embedding dim. 
+ We use two statistics to unify BSHD and THD cases: + num_total_tokens_this_global_batch: total number of tokens in this global batch + sequence_square_sum_this_global_batch: sum of squared sequence lengths in this global batch + + For one self-attention block (prenorm not included): + qkv projection: 6 * num_total_tokens_this_global_batch * h^2 + attn: 2 * sequence_square_sum_this_global_batch * h + attn over value: 2 * sequence_square_sum_this_global_batch * h + oproj: 2 * num_total_tokens_this_global_batch * h^2 + references https://arxiv.org/abs/2305.10403 https://arxiv.org/abs/2205.05198 @@ -375,7 +376,7 @@ def transformer_flops(): standard_self_attn_term = ( 3 * 2 # fwd(1) + bwd(2) *FMA - * ( + * ( num_total_tokens_this_global_batch * ( ## q lora + rope + q norm q_term ## kv lora + rope + kv norm @@ -387,12 +388,12 @@ def transformer_flops(): ) + args.hidden_size * args.qk_pos_emb_head_dim ## o proj - + (args.num_attention_heads * args.v_head_dim) * args.hidden_size + + (args.num_attention_heads * args.v_head_dim) * args.hidden_size) ## core attn - + args.seq_length + + sequence_square_sum_this_global_batch * (args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim)) - / 2 # causal mask (only half of the mask is non-zero) - + args.seq_length * args.num_attention_heads * args.v_head_dim / 2 + / 2 # causal mask (only half of the mask is non-zero) + + sequence_square_sum_this_global_batch * args.num_attention_heads * args.v_head_dim / 2 ) ) @@ -404,17 +405,17 @@ def transformer_flops(): standard_self_attn_term = ( 3 * 2 # fwd(1) + bwd(2) *FMA - * ( + * ( num_total_tokens_this_global_batch *( ## qkv proj args.hidden_size - * (query_projection_size + key_projection_size + value_projection_size) + * (query_projection_size + key_projection_size + value_projection_size)) ## core attention + query_projection_size - * args.seq_length + * sequence_square_sum_this_global_batch / 2 # causal mask (only half of the mask is non-zero) * 2 # QK^T and (QK^T)V 
## out proj - + query_projection_size + + num_total_tokens_this_global_batch * query_projection_size * args.hidden_size ) ) @@ -487,8 +488,7 @@ def transformer_flops(): ) total_floating_point_operations = ( - batch_size - * args.seq_length + num_total_tokens_this_global_batch * ( # MLP expansion_factor @@ -505,8 +505,6 @@ def transformer_flops(): + (shared_expert_ffn_hidden_size * gated_linear_multiplier) * (num_moe_layers / num_layers) ) - # Self Attention - + self_attn_term # MTP norms and proj + 3 * 2 @@ -520,6 +518,10 @@ def transformer_flops(): # Logit. + 3 * 2 * args.hidden_size * args.padded_vocab_size * (mtp_num_layers + 1) ) + + + # Self Attention + self_attn_term + ) return total_floating_point_operations @@ -530,8 +532,8 @@ def transformer_flops(): # Compute hybrid model FLOPs. return hybrid_flops( - batch_size=batch_size, - seq_len=args.seq_length, + num_total_tokens_this_global_batch=num_total_tokens_this_global_batch, + sequence_square_sum_this_global_batch=sequence_square_sum_this_global_batch, hidden_size=args.hidden_size, num_attn_layers=num_attn_layers, num_mamba_layers=num_mamba_layers, @@ -1477,9 +1479,16 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch forward_only=False, adjust_tensor_shapes_fn=adjust_tensor_shapes_fn, ) + + if args.sequence_packing: + num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch = losses_reduced.pop() + else: + sequence_square_sum_this_global_batch = args.seq_length ** 2 * args.micro_batch_size * args.data_parallel_size * get_num_microbatches() + num_total_tokens_this_global_batch = args.seq_length * args.micro_batch_size * args.data_parallel_size * get_num_microbatches() + should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit() if should_exit: - return {}, True, should_checkpoint, should_exit, exit_code, None, None, 0 + return {}, True, should_checkpoint, should_exit, exit_code, None, None, num_total_tokens_this_global_batch, 
sequence_square_sum_this_global_batch, 0 # Empty unused memory. if args.empty_unused_memory_level >= 1: @@ -1559,8 +1568,10 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch grad_norm, num_zeros_in_grad, log_max_attention_logit, + num_total_tokens_this_global_batch, + sequence_square_sum_this_global_batch, ) - return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad, log_max_attention_logit + return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad, log_max_attention_logit, num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch def training_log( @@ -1575,6 +1586,8 @@ def training_log( params_norm, num_zeros_in_grad, max_attention_logit, + num_total_tokens_this_global_batch, + sequence_square_sum_this_global_batch, pg_collection=None, ): """Log training information such as losses, timing, ....""" @@ -1797,7 +1810,7 @@ def training_log( elapsed_time = timers('interval-time').elapsed(barrier=True) elapsed_time_per_iteration = elapsed_time / total_iterations - throughput = num_floating_point_operations(args, batch_size) / ( + throughput = num_floating_point_operations(args,num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch) / ( elapsed_time_per_iteration * 10**12 * args.world_size ) @@ -2243,9 +2256,6 @@ def train( energy_monitor = get_energy_monitor() one_logger = get_one_logger() - if args.hybrid_context_parallel: - train_data_iterator = iter(HybridCPDataLoaderWrapper(train_data_iterator, config)) - if args.run_workload_inspector_server: try: from workload_inspector.utils.webserver import run_server @@ -2499,6 +2509,8 @@ def get_e2e_base_metrics(): # Completely skip iteration if needed. 
if iteration in args.iterations_to_skip: + # TODO(tailaim): this need to be modified + assert not args.sequence_packing, "Sequence packing is not supported in skip iteration mode" # Dummy train_step to fast forward train_data_iterator. dummy_train_step(train_data_iterator) if iteration == start_iteration: @@ -2534,6 +2546,8 @@ def get_e2e_base_metrics(): grad_norm, num_zeros_in_grad, max_attention_logit, + num_total_tokens_this_global_batch, + sequence_square_sum_this_global_batch, ) = train_step( forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func ) @@ -2618,7 +2632,7 @@ def get_e2e_base_metrics(): else: assert num_skipped_samples_in_batch == 0 args.skipped_train_samples += num_skipped_samples_in_batch - num_floating_point_operations_in_batch = num_floating_point_operations(args, batch_size) + num_floating_point_operations_in_batch = num_floating_point_operations(args, num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch) num_floating_point_operations_so_far += num_floating_point_operations_in_batch num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch @@ -2649,6 +2663,8 @@ def get_e2e_base_metrics(): params_norm, num_zeros_in_grad, max_attention_logit, + num_total_tokens_this_global_batch, + sequence_square_sum_this_global_batch, pg_collection=model_pg_collection, ) diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 4730a525271..52a3bf36d88 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -541,58 +541,19 @@ def _broadcast(item): else data["attention_mask"].cuda(non_blocking=True) ), 'position_ids': data["position_ids"].cuda(non_blocking=True), - 'cu_seqlens': ( - None - if "cu_seqlens" not in data - else data["cu_seqlens"].cuda(non_blocking=True) - ), - 'max_seqlen': ( - None - if "max_seqlen" not in data - else data["max_seqlen"].cuda(non_blocking=True) - ), - 'local_cp_size': ( - None - if 
"local_cp_size" not in data - else data["local_cp_size"].cuda(non_blocking=True) - ), } - def _broadcast_cu_seqlens(cu_seqlens): - dev = torch.cuda.current_device() - n = 0 if cu_seqlens is None else int(cu_seqlens.numel()) - n_tensor = torch.tensor(n, dtype=torch.int64, device=dev) - _broadcast(n_tensor) - - if n == 0: - buf = torch.empty(0, dtype=torch.int32, device=dev) - else: - assert isinstance(cu_seqlens, torch.Tensor) - assert cu_seqlens.dtype == torch.int32 - assert cu_seqlens.shape[0] == 1, "micro-batch-size must be 1 for packing" - buf = cu_seqlens.to(device=dev, non_blocking=True).contiguous() - _broadcast(buf) - - if args.hybrid_context_parallel: - seq_len = torch.tensor(batch['tokens'].shape[0], dtype=torch.int32, device=torch.cuda.current_device()) - _broadcast(seq_len) - if args.pipeline_model_parallel_size == 1 or mtp_on_this_rank: _broadcast(batch['tokens']) _broadcast(batch['labels']) _broadcast(batch['loss_mask']) _broadcast(batch['attention_mask']) _broadcast(batch['position_ids']) - _broadcast_cu_seqlens(batch['cu_seqlens']) - _broadcast(batch['max_seqlen']) - _broadcast(batch['local_cp_size']) elif mpu.is_pipeline_first_stage(): _broadcast(batch['tokens']) _broadcast(batch['attention_mask']) _broadcast(batch['position_ids']) - _broadcast_cu_seqlens(batch['cu_seqlens']) - _broadcast(batch['max_seqlen']) elif mpu.is_pipeline_last_stage(): # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding. 
@@ -603,79 +564,42 @@ def _broadcast_cu_seqlens(cu_seqlens): _broadcast(batch['attention_mask']) else: - if args.hybrid_context_parallel: - seq_len = torch.tensor(0, dtype=torch.int32, device=torch.cuda.current_device()) - _broadcast(seq_len) - shape = (seq_len.item()) - else: - shape = (args.micro_batch_size, args.seq_length) - + tokens = torch.empty( - shape, + (args.micro_batch_size, args.seq_length), dtype=torch.int64, device=torch.cuda.current_device(), ) labels = torch.empty( - shape, + (args.micro_batch_size, args.seq_length), dtype=torch.int64, device=torch.cuda.current_device(), ) loss_mask = torch.empty( - shape, + (args.micro_batch_size, args.seq_length), dtype=torch.float32, device=torch.cuda.current_device(), ) if args.create_attention_mask_in_dataloader: - shape_attention_mask = (args.micro_batch_size, 1, args.seq_length, args.seq_length) if not args.hybrid_context_parallel else (1, 1, shape[0], shape[0]) attention_mask = torch.empty( - shape_attention_mask, + (args.micro_batch_size, 1, args.seq_length, args.seq_length), dtype=torch.bool, device=torch.cuda.current_device(), ) else: attention_mask = None position_ids = torch.empty( - shape, + (args.micro_batch_size, args.seq_length), dtype=torch.int64, device=torch.cuda.current_device(), ) - cu_seqlens = None - max_seqlen = torch.empty( - 1, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) if args.hybrid_context_parallel else None - local_cp_size = torch.empty( - 1, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) if args.hybrid_context_parallel else None - - def _broadcast_cu_seqlens(): - dev = torch.cuda.current_device() - - n = torch.empty((), dtype=torch.int64, device=dev) - _broadcast(n) - n = int(n.item()) - - if n == 0: - cu_seqlens = torch.empty(0, dtype=torch.int32, device=dev) - else: - cu_seqlens = torch.empty((args.micro_batch_size, n), dtype=torch.int32, device=dev) - _broadcast(cu_seqlens) - - return cu_seqlens if n > 0 else None - if 
args.pipeline_model_parallel_size == 1 or mtp_on_this_rank: _broadcast(tokens) _broadcast(labels) _broadcast(loss_mask) _broadcast(attention_mask) _broadcast(position_ids) - cu_seqlens = _broadcast_cu_seqlens() - _broadcast(max_seqlen) - _broadcast(local_cp_size) elif mpu.is_pipeline_first_stage(): labels = None @@ -684,8 +608,6 @@ def _broadcast_cu_seqlens(): _broadcast(tokens) _broadcast(attention_mask) _broadcast(position_ids) - cu_seqlens = _broadcast_cu_seqlens() - _broadcast(max_seqlen) elif mpu.is_pipeline_last_stage(): # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding. @@ -693,8 +615,7 @@ def _broadcast_cu_seqlens(): # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage. tokens = None position_ids = None - cu_seqlens = None - max_seqlen = None + _broadcast(labels) _broadcast(loss_mask) _broadcast(attention_mask) @@ -705,9 +626,6 @@ def _broadcast_cu_seqlens(): 'loss_mask': loss_mask, 'attention_mask': attention_mask, 'position_ids': position_ids, - 'cu_seqlens': cu_seqlens, - 'max_seqlen': max_seqlen, - 'local_cp_size': local_cp_size, } return batch diff --git a/pretrain_gpt.py b/pretrain_gpt.py index cfb5e1b5f1f..eec02b8a78d 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -11,10 +11,11 @@ from megatron.core import parallel_state from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset +from megatron.core.datasets.data_schedule import get_batch_on_this_rank_for_sequence_packing from megatron.core.enums import ModelType from megatron.core.models.gpt import GPTModel from megatron.core.rerun_state_machine import get_rerun_state_machine -from megatron.core.utils import get_attr_wrapped_model, get_thd_batch_on_this_cp_rank, get_batch_on_this_hybrid_cp_rank, StragglerDetector +from megatron.core.utils import get_attr_wrapped_model, StragglerDetector 
from megatron.core.tokenizers.text.utils.build_tokenizer import build_tokenizer from megatron.core.transformer.multi_token_prediction import mtp_on_this_rank, get_mtp_ranks from megatron.training.arguments import core_transformer_config_from_args @@ -27,6 +28,7 @@ get_blend_and_blend_per_split, is_first_or_last_pipeline_stage, ) +from megatron.training.datasets.sft_dataset import SFTDataset, MockSFTDataset from model_provider import model_provider try: @@ -44,6 +46,15 @@ def get_batch(data_iterator, vp_stage: Optional[int] = None): """Generate a batch.""" args = get_args() config = core_transformer_config_from_args(args) + + if args.sequence_packing: + return get_batch_on_this_rank_for_sequence_packing( + data_iterator, + mtp_on_this_rank=mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage), + vp_stage=vp_stage, + dynamic_context_parallel=args.dynamic_context_parallel, + ) + # TODO: this is pretty hacky, find a better way if not is_first_or_last_pipeline_stage(vp_stage) and ( (not mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage))): @@ -54,24 +65,8 @@ def get_batch(data_iterator, vp_stage: Optional[int] = None): data_iterator, mtp_on_this_rank=mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage) ) - - cu_seqlens = batch.pop('cu_seqlens', None) - cu_seqlens_padded = batch.pop('cu_seqlens_padded', None) - max_seqlen = batch.pop('max_seqlen', None) - local_cp_size = batch.pop('local_cp_size', None) - if local_cp_size is not None: - local_cp_size = int(local_cp_size.item()) - - if cu_seqlens is None and local_cp_size is None: - # slice batch along sequence dimension for context parallelism - batch = get_batch_on_this_cp_rank(batch) # The implementation of this function is in MCore - packed_seq_params = None - elif local_cp_size is None: # Packed THD format - assert max_seqlen.dim() == 1 - batch, packed_seq_params = get_thd_batch_on_this_cp_rank(batch, cu_seqlens, cu_seqlens_padded, max_seqlen) - else: # Hybrid CP format - 
batch, packed_seq_params = get_batch_on_this_hybrid_cp_rank(batch, local_cp_size) - + batch = get_batch_on_this_cp_rank(batch) + packed_seq_params = None return (*batch.values(), packed_seq_params) @@ -222,7 +217,9 @@ def core_gpt_dataset_config_from_args(args): "context_parallel_size": args.context_parallel_size, "data_parallel_size": args.data_parallel_size, "sequence_parallel_size": args.tensor_model_parallel_size*args.sequence_parallel, - "hybrid_context_parallel": args.hybrid_context_parallel, + "dynamic_context_parallel": args.dynamic_context_parallel, + "sft_mock_dataset_config_json":args.sft_mock_dataset_config_json, + "sequence_packing": args.sequence_packing, } # add FIM args to the config @@ -260,7 +257,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None config = core_gpt_dataset_config_from_args(args) if args.sft: - dataset_type = SFTDataset + if args.mock_data: + dataset_type = MockSFTDataset + else: + dataset_type = SFTDataset else: if args.mock_data: dataset_type = MockGPTDataset diff --git a/pretrain_mamba.py b/pretrain_mamba.py index ca2008620be..c7ffd7a7d11 100644 --- a/pretrain_mamba.py +++ b/pretrain_mamba.py @@ -49,7 +49,7 @@ def get_batch(data_iterator, vp_stage=None): cu_seqlens = batch.pop('cu_seqlens', None) cu_seqlens_padded = batch.pop('cu_seqlens_padded', None) max_seqlen = batch.pop('max_seqlen', None) - # Support for Hybrid Context Parallel (Unused in this script) + # Support for Dynamic Context Parallel (Unused in this script) local_cp_size = batch.pop('local_cp_size', None) # slice batch along sequence dimension for context parallelism diff --git a/tests/unit_tests/context_parallel/test_packing_and_hybrid_cp.py b/tests/unit_tests/context_parallel/test_packing_and_hybrid_cp.py new file mode 100644 index 00000000000..0ee46186993 --- /dev/null +++ b/tests/unit_tests/context_parallel/test_packing_and_hybrid_cp.py @@ -0,0 +1,1016 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + +import os +from functools import partial +from pathlib import Path +from types import SimpleNamespace + +import pytest +import torch +import numpy +import torch.distributed + +from megatron.core import mpu, parallel_state +from megatron.core.datasets.data_schedule import PackingScheduler, wrap_dataloader, get_batch_on_this_rank_for_sequence_packing +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_with_transformer_engine_spec as gpt_te_spec, +) +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.num_microbatches_calculator import ( + init_num_microbatches_calculator, + unset_num_microbatches_calculator, +) +from megatron.core.rerun_state_machine import RerunDataIterator +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.multi_token_prediction import mtp_on_this_rank +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.training.checkpointing import load_checkpoint, save_checkpoint +from megatron.training.global_vars import set_args, set_global_variables, unset_global_variables +from tests.unit_tests.dist_checkpointing import TempNamedDir +from tests.unit_tests.dist_checkpointing.models.common import ( + common_test_parallel_reconfiguration_e2e, +) +from tests.unit_tests.test_utilities import Utils + + +@pytest.fixture +def create_args(): + """Setup dummy args.""" + args = SimpleNamespace() + args.finetune = False + args.non_persistent_global_ckpt_dir = None + args.non_persistent_ckpt_type = None + args.non_persistent_save_interval = None + args.exit_on_missing_checkpoint = True + args.async_save = False + args.data_parallel_random_init = False + args.log_progress = False + 
args.ckpt_fully_parallel_save = False + args.ckpt_fully_parallel_load = False + args.auto_detect_ckpt_format = False + args.retro_add_retriever = False + args.ckpt_convert_update_legacy_dist_opt_format = False + args.ckpt_step = None + args.use_dist_ckpt = True + args.consumed_train_samples = 0 + args.skipped_train_samples = 0 + args.consumed_valid_samples = 0 + args.vocab_file = None + args.add_position_embedding = False + args.ckpt_assume_constant_structure = True + args.dist_ckpt_strictness = "assume_ok_unexpected" + args.fp16 = False + args.bf16 = True + args.no_save_optim = True + args.no_save_rng = True + args.no_load_optim = True + args.no_load_rng = True + args.use_distributed_optimizer = True + args.use_megatron_fsdp = False + args.dist_ckpt_save_pre_mcore_014 = False + args.dist_ckpt_optim_fully_reshardable = False + args.distrib_optim_fully_reshardable_mem_efficient = False + args.data_path = None + args.mock_data = True + args.data_args_path = None + args.train_data_path, args.valid_data_path, args.test_data_path = None, None, None + args.per_split_data_args_path = None + args.rank = int(os.getenv('RANK', '0')) + # Model config + args.num_layers = 8 + args.hidden_size = 128 + args.num_attention_heads = 8 + # Ckpt format + args.ckpt_format = "torch_dist" + args.tokenizer_type = 'NullTokenizer' + args.vocab_size = 16384 + args.make_vocab_size_divisible_by = 128 + args.padded_vocab_size = 16512 + args.reset_position_ids = False + args.reset_attention_mask = False + args.eod_mask_loss = False + args.multi_latent_attention = False + args.heterogeneous_layers_config_path = None + args.no_persist_layer_norm = True + args.apply_layernorm_1p = False + args.norm_epsilon = 1e-6 + args.params_dtype = torch.bfloat16 + args.overlap_p2p_comm = True + args.rotary_interleaved = False + args.decoder_first_pipeline_num_layers = None + args.decoder_last_pipeline_num_layers = None + args.fp8_param_gather = False + args.swiglu = True + args.bias_swiglu_fusion = False + 
args.squared_relu = False + args.init_method_xavier_uniform = False + args.quick_geglu = False + args.group_query_attention = True + args.config_logger_dir = None + args.rope_type = None + args.is_hybrid_model = False + args.num_query_groups = 4 + args.cp_comm_type = ['p2p'] + args.seed = 123 + args.rampup_batch_size = None + args.global_batch_size = 256 + args.micro_batch_size = 1 + args.decrease_batch_size_if_needed = False + args.enable_one_logger = True + args.one_logger_async = False + args.adlr_autoresume = False + args.adlr_autoresume_interval = 1000 + args.timing_log_level = 0 + args.timing_log_option = "minmax" + args.enable_experimental = False + args.exit_signal_handler = False + args.disable_jit_fuser = False + args.one_logger_project = "megatron-lm" + args.one_logger_run_name = None + args.tensorboard_dir = None + args.tensorboard_queue_size = 1000 + args.wandb_project = "" + args.wandb_exp_name = "" + args.wandb_save_dir = "" + args.wandb_entity = "" + args.iteration = 0 + args.train_samples = 100000 + args.full_validation = False + args.train_iters = 100 + args.dataloader_type = "single" + args.eval_iters = 32 + args.eval_interval = 500 + args.save_interval = 500 + args.exit_interval = None + args.exit_duration_in_mins = None + args.legacy_tokenizer = True + args.split = "99,1,0" + args.multiple_validation_sets = False + args.num_dataset_builder_threads = 1 + args.num_workers = 2 + args.skip_train = False + args.data_cache_path = None + args.mmap_bin_files = False + args.object_storage_cache_path = None + args.mid_level_dataset_surplus = 0.005 + args.create_attention_mask_in_dataloader = False + args.sft_mock_dataset_config_json = None + args.sequence_packing_scheduler = None + args.check_for_nan_in_loss_and_grad = False + args.check_for_spiky_loss = False + args.sequence_parallel = False + args.untie_embeddings_and_output_weights = True + args.hidden_dropout = 0.0 + args.attention_dropout = 0.0 + args.moe_ffn_hidden_size = None + 
args.use_legacy_models = False + args.allow_ambiguous_pad_tokens = False + args.add_bias_linear = False + args.sft = True + args.overlap_moe_expert_parallel_comm = False + args.sft_mock_dataset_config_json = '{"mode":"distribution","type":"lognormal","min_seq_len":1024,"max_seq_len":8192,"mean_seq_len":4096,"lognormal_sigma":1.1}' + args.world_size = 8 + args.seq_length = 8192 + args.max_position_embeddings = 8192 + args.max_seqlen_per_dp_cp_rank = None + args.variable_seq_lengths = False + args.moe_token_dispatcher_type = "allgather" + args.moe_latent_size = None + args.te_precision_config_file = None + + yield args + + +def initialize_gpt_model( + args, + layer_spec_fn=gpt_te_spec, + virtual_pipeline_model_parallel_size=None, + is_moe=False, + with_mtp=False, + **config_kwargs, +): + torch.manual_seed(args.seed) + model_parallel_cuda_manual_seed(args.seed) + + # NOTE: This unit test uses TP/PP/CP (and optionally dynamic-CP). We must pass the + # model-parallel sizes into TransformerConfig; otherwise it defaults to cp=1 which + # breaks RoPE sharding (cp_group.size()>1 but config.context_parallel_size==1). 
+ default_config_kwargs = dict( + num_layers=args.num_layers, + hidden_size=args.hidden_size, + num_attention_heads=args.num_attention_heads, + use_cpu_initialization=True, + pipeline_dtype=args.params_dtype, + bf16=args.bf16, + tensor_model_parallel_size=args.tensor_model_parallel_size, + pipeline_model_parallel_size=args.pipeline_model_parallel_size, + context_parallel_size=args.context_parallel_size, + sequence_parallel=args.sequence_parallel, + dynamic_context_parallel=args.dynamic_context_parallel, + sequence_packing_scheduler=args.sequence_packing_scheduler, + sequence_packing=getattr(args, "sequence_packing", False), + max_seqlen_per_dp_cp_rank=getattr(args, "max_seqlen_per_dp_cp_rank", None), + virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size, + hidden_dropout=args.hidden_dropout, + attention_dropout=args.attention_dropout, + mtp_num_layers=1 if with_mtp else None, + mtp_loss_scaling_factor=1.0 if with_mtp else None, + variable_seq_lengths=args.variable_seq_lengths, + moe_token_dispatcher_type=args.moe_token_dispatcher_type, + ) + default_config_kwargs.update(**config_kwargs) + + transformer_config = TransformerConfig(**default_config_kwargs) + if is_moe: + # transformer_config.moe_layer_freq = [0, 1, 1, 1, 1, 0, 1, 0] + transformer_config.moe_ffn_hidden_size = args.moe_ffn_hidden_size + transformer_config.num_moe_experts = args.num_experts + transformer_config.add_bias_linear = args.add_bias_linear + + model = [] + for i in range(virtual_pipeline_model_parallel_size or 1): + if is_moe: + layer_spec = layer_spec_fn(transformer_config, use_transformer_engine=True, vp_stage=i) + else: + layer_spec = layer_spec_fn() + + if with_mtp and mtp_on_this_rank(transformer_config, ignore_virtual=False, vp_stage=i): + if is_moe: + transformer_layer_spec_for_mtp = gpt_te_spec(transformer_config) + else: + transformer_layer_spec_for_mtp = layer_spec + mtp_block_spec = get_gpt_mtp_block_spec( + transformer_config, + 
transformer_layer_spec_for_mtp, + use_transformer_engine=True, + vp_stage=i, + ) + else: + mtp_block_spec = None + + # print("========================") + # print("[DEBUG] mtp_block_spec is ", mtp_block_spec) + # exit() + pre_process = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i) + post_process = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i) + this_model = ( + GPTModel( + config=transformer_config, + transformer_layer_spec=layer_spec, + vocab_size=args.padded_vocab_size, + pre_process=pre_process, + post_process=post_process, + position_embedding_type="rope", + vp_stage=i, + mtp_block_spec=mtp_block_spec, + share_embeddings_and_output_weights=False, + max_sequence_length=args.seq_length, + ) + .bfloat16() + .cuda() + ) + this_model.model_type = ModelType.encoder_or_decoder + model.append(this_model) + + if virtual_pipeline_model_parallel_size is None: + model = model[0] + return model + + +def get_data_iterator(args): + """ + Get the data iterator for the test. 
+ + Args: + args: args namespace + """ + from megatron.core.datasets.blended_megatron_dataset_builder import ( + BlendedMegatronDatasetBuilder, + ) + from megatron.core.datasets.gpt_dataset import GPTDatasetConfig + from megatron.training import get_tokenizer + from megatron.training.datasets.sft_dataset import MockSFTDataset, MockSFTLowLevelDataset + from megatron.training.training import build_train_valid_test_data_iterators + from megatron.training.utils import ( + get_blend_and_blend_per_split + ) + from pretrain_gpt import is_dataset_built_on_rank + + blend, blend_per_split = get_blend_and_blend_per_split(args) + # rebuild_tokenizer(args) + tokenizer = get_tokenizer() + dataset_config = GPTDatasetConfig( + random_seed=123, + sequence_length=args.seq_length, + blend=blend, + blend_per_split=blend_per_split, + split='969, 30, 1', + tokenizer=tokenizer, + create_attention_mask=False, + reset_position_ids=args.reset_position_ids, + reset_attention_mask=args.reset_attention_mask, + eod_mask_loss=args.eod_mask_loss, + context_parallel_size=args.context_parallel_size, + data_parallel_size=args.data_parallel_size, + sequence_parallel_size=args.tensor_model_parallel_size, + dynamic_context_parallel=args.dynamic_context_parallel, + sft_mock_dataset_config_json=args.sft_mock_dataset_config_json, + sequence_packing=args.sequence_packing, + ) + train_ds, _, _ = BlendedMegatronDatasetBuilder( + MockSFTDataset, + [100000, 2560, 2560], + partial(is_dataset_built_on_rank, vp_stage=None), + dataset_config, + ).build() + + is_tp_first = parallel_state.get_tensor_model_parallel_rank() == 0 + is_pp_first = parallel_state.get_pipeline_model_parallel_rank() == 0 + is_pp_last = (parallel_state.get_pipeline_model_parallel_rank() == + parallel_state.get_pipeline_model_parallel_world_size() - 1) + + if is_tp_first and (is_pp_first or is_pp_last): + dp_rank = parallel_state.get_data_parallel_rank() + dp_size = args.data_parallel_size // args.context_parallel_size + num_microbatches = 
args.global_batch_size // dp_size // args.micro_batch_size + start_index = dp_rank * num_microbatches + end_index = start_index + num_microbatches + samples = [train_ds[i] for i in range(start_index, end_index)] + if args.sequence_packing: + data_iterator = RerunDataIterator(iter([samples])) + else: + for sample in samples: + sample['tokens'] = sample['tokens'].unsqueeze(0) + sample['labels'] = sample['labels'].unsqueeze(0) + sample['loss_mask'] = sample['loss_mask'].unsqueeze(0) + sample['position_ids'] = sample['position_ids'].unsqueeze(0) + data_iterator = RerunDataIterator(iter(samples)) + else: + data_iterator = None + + if (args.virtual_pipeline_model_parallel_size is not None and + args.virtual_pipeline_model_parallel_size > 1): + vpp_size = args.virtual_pipeline_model_parallel_size + if is_pp_first: + data_iterator = [data_iterator] + [None for _ in range(vpp_size - 1)] + elif is_pp_last: + data_iterator = [None for _ in range(vpp_size - 1)] + [data_iterator] + else: + data_iterator = [None for _ in range(vpp_size)] + + return data_iterator + +# Dense and MoE Models +@pytest.mark.parametrize( + ('tp_pp_cp_vpp', 'is_moe'), + [ + ((1, 2, 1, None), True), + ((1, 4, 1, None), True), + ((2, 2, 2, None), True), + ((1, 1, 2, None), True), + ((1, 2, 1, None), False), + ((1, 4, 1, None), False), + ((2, 2, 2, None), False), + ((1, 1, 2, None), False), + ], +) +def test_packing_and_dynamic_cp(create_args, tp_pp_cp_vpp, is_moe): + + def _compute_avg_loss(losses_list): + """Compute the average loss over all micro-batches.""" + total_loss_sum = 0.0 + total_tokens = 0.0 + for loss_dict in losses_list: + if isinstance(loss_dict, dict) and 'lm loss' in loss_dict: + t = loss_dict['lm loss'] + if torch.is_tensor(t) and t.dim() == 1 and t.numel() == 2: + total_loss_sum += t[0].item() + total_tokens += t[1].item() + if total_tokens > 0: + return total_loss_sum / total_tokens + return 0.0 + + args = create_args + losses_reduced_baseline, is_last_stage = dummy_forward_func( + args, + 
is_sequence_packing=False, + is_dynamic_context_parallel=False, + tp_pp_cp_vpp=tp_pp_cp_vpp, + is_moe=is_moe, + ) + losses_reduce_packing, _ = dummy_forward_func( + args, + is_sequence_packing=True, + is_dynamic_context_parallel=False, + tp_pp_cp_vpp=tp_pp_cp_vpp, + is_moe=is_moe, + ) + losses_reduced_hybrid, _ = dummy_forward_func( + args, + is_sequence_packing=True, + is_dynamic_context_parallel=True, + tp_pp_cp_vpp=tp_pp_cp_vpp, + is_moe=is_moe, + ) + if is_last_stage and torch.distributed.get_rank() == 0: + avg_baseline = _compute_avg_loss(losses_reduced_baseline) + avg_packing = _compute_avg_loss(losses_reduce_packing) + avg_hybrid = _compute_avg_loss(losses_reduced_hybrid) + print(f"avg_loss_baseline: {avg_baseline:.6f}") + print(f"avg_loss_packing: {avg_packing:.6f}") + print(f"avg_loss_hybrid: {avg_hybrid:.6f}") + + # NOTE: dummy_forward_func() destroys model-parallel groups before returning. + # So we must not query parallel_state after it returns. + if is_last_stage: + avg_baseline = _compute_avg_loss(losses_reduced_baseline) + avg_packing = _compute_avg_loss(losses_reduce_packing) + avg_hybrid = _compute_avg_loss(losses_reduced_hybrid) + + rtol = 1e-3 # relative tolerance: 0.1% + + # relative error: |a - b| / |b| < rtol + rel_err_packing = abs(avg_packing - avg_baseline) / abs(avg_baseline) if avg_baseline != 0 else 0 + rel_err_hybrid = abs(avg_hybrid - avg_baseline) / abs(avg_baseline) if avg_baseline != 0 else 0 + + assert rel_err_packing < rtol, \ + f"packing avg loss {avg_packing:.6f} vs baseline {avg_baseline:.6f}, rel_err={rel_err_packing:.6e}" + assert rel_err_hybrid < rtol, \ + f"hybrid avg loss {avg_hybrid:.6f} vs baseline {avg_baseline:.6f}, rel_err={rel_err_hybrid:.6e}" + + print("test_packing_and_dynamic_cp passed with tp_pp_cp_vpp: ", tp_pp_cp_vpp, "is_moe: ", is_moe) + + +def dummy_forward_func( + args, is_sequence_packing, is_dynamic_context_parallel, tp_pp_cp_vpp, is_moe +): + from megatron.core.pipeline_parallel import get_forward_backward_func + from 
pretrain_gpt import forward_step, get_batch + + args.sequence_packing = is_sequence_packing + args.dynamic_context_parallel = is_dynamic_context_parallel + + args.num_experts = 4 if is_moe else None + args.moe_ffn_hidden_size = 768 if is_moe else None + + def set_tp_pp_vpp(tp, pp, cp, vpp=None, destroy_first=True): + if destroy_first: + Utils.destroy_model_parallel() + args.tensor_model_parallel_size = tp + args.pipeline_model_parallel_size = pp + args.virtual_pipeline_model_parallel_size = vpp + args.data_parallel_size = 8 // (tp * pp) + # Dynamic-CP requires context_parallel_size == 1; CP is achieved via DPxCP hybrid groups. + args.context_parallel_size = 1 if args.dynamic_context_parallel else cp + if args.dynamic_context_parallel: + dp_cp_size = args.data_parallel_size * args.context_parallel_size + if dp_cp_size % 2 != 0: + pytest.skip( + "Dynamic context parallel requires an even dp-cp group size" + ) + if tp > 1: + args.sequence_parallel = True + Utils.initialize_model_parallel( + tp, + pp, + vpp, + context_parallel_size=args.context_parallel_size, + dynamic_context_parallel=args.dynamic_context_parallel, + min_dynamic_context_parallel_size=getattr(args, "min_dynamic_context_parallel_size", 1), + ) + + set_tp_pp_vpp(*tp_pp_cp_vpp) + if is_sequence_packing: + args.variable_seq_lengths = True + # TODO(tailaim): add support for other dispatcher types + print( + f"Setting moe_token_dispatcher_type to alltoall for sft sequence packing with pipeline parallelism" + ) + args.moe_token_dispatcher_type = "alltoall" + if is_dynamic_context_parallel: + args.max_seqlen_per_dp_cp_rank = args.seq_length // args.data_parallel_size + args.sequence_packing_scheduler = "default_dynamic_cp" + else: + args.max_seqlen_per_dp_cp_rank = args.seq_length // args.context_parallel_size + args.sequence_packing_scheduler = "naive_sequence_packing" + else: + args.sequence_packing_scheduler = None + + set_global_variables(args) + # set_args(args) + + layer_spec_fn = 
get_gpt_decoder_block_spec if is_moe else gpt_te_spec + model = initialize_gpt_model( + args, + layer_spec_fn=layer_spec_fn, + num_layers=args.num_layers, + hidden_size=args.hidden_size, + num_attention_heads=args.num_attention_heads, + tensor_model_parallel_size=args.tensor_model_parallel_size, + pipeline_model_parallel_size=args.pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size, + is_moe=is_moe, + with_mtp=False, + ) + model = model if isinstance(model, list) else [model] + + data_iterator = get_data_iterator(args) + + forward_backward_func = get_forward_backward_func() + losses_reduced = forward_backward_func( + forward_step_func=forward_step, + data_iterator=data_iterator, + model=model, + num_microbatches=args.global_batch_size + // args.data_parallel_size + * args.context_parallel_size + // args.micro_batch_size, + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + forward_only=True, + ) + + # Capture pipeline stage info BEFORE destroying model-parallel state. + is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=True) + + Utils.destroy_model_parallel() + unset_num_microbatches_calculator() + + unset_global_variables() + + return losses_reduced, is_last_stage + + +class MockVariableLengthSequencePackingDataIterator: + """ + Mock data iterator for testing get_batch_on_this_rank_for_sequence_packing. + + Generates variable-length (THD format) packed sequences with deterministic + data for verification across parallel ranks. + """ + + def __init__( + self, + total_seq_length: int, + sequence_lengths: list, + local_cp_size: int = None, + device: str = "cuda", + seed: int = 42, + ): + """ + Args: + total_seq_length: Total length of packed sequences + sequence_lengths: List of individual sequence lengths (variable-length). + If None, generates random variable lengths. 
+ local_cp_size: Local CP size for dynamic context parallel + device: Device to create tensors on + seed: Random seed for reproducibility + """ + self.total_seq_length = total_seq_length + self.sequence_lengths = sequence_lengths + self.local_cp_size = local_cp_size + self.device = device + self.seed = seed + assert ( + sum(self.sequence_lengths) == total_seq_length + ), f"Sequence lengths sum {sum(self.sequence_lengths)} != total {total_seq_length}" + + def __iter__(self): + """Interface for the data iterator.""" + return self + + def __next__(self): + """Generate a mock batch with variable-length THD format.""" + dev = self.device + torch.manual_seed(self.seed) + torch.cuda.manual_seed(self.seed) + + tokens = torch.randint(0, 16384, (self.total_seq_length,), dtype=torch.int64, device=dev) + + # Create position_ids that reset for each sequence (THD format) + position_ids = [] + for seq_len in self.sequence_lengths: + position_ids.extend(range(seq_len)) + position_ids = torch.tensor(position_ids, dtype=torch.int64, device=dev) + + # Labels are tokens shifted by 1 for easy verification + labels = tokens + 1 + + # Loss mask: 1.0 for all positions except padding (none here) + loss_mask = torch.ones(self.total_seq_length, dtype=torch.float32, device=dev) + + # Create cu_seqlens for variable-length packed sequences + cu_seqlens = [0] + for seq_len in self.sequence_lengths: + cu_seqlens.append(cu_seqlens[-1] + seq_len) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=dev) + cu_seqlens_padded = cu_seqlens.clone() + + max_seqlen = torch.tensor([max(self.sequence_lengths)], dtype=torch.int32, device=dev) + + batch = { + "tokens": tokens, + "position_ids": position_ids, + "labels": labels, + "loss_mask": loss_mask, + "cu_seqlens": cu_seqlens, + "cu_seqlens_padded": cu_seqlens_padded, + "max_seqlen": max_seqlen, + } + + if not ( + parallel_state.is_pipeline_first_stage(ignore_virtual=True) + or parallel_state.is_pipeline_last_stage(ignore_virtual=True) + ): + 
batch["tokens"] = None + batch["position_ids"] = None + batch["labels"] = None + batch["loss_mask"] = None + + if self.local_cp_size is not None: + batch["local_cp_size"] = torch.tensor( + [self.local_cp_size], dtype=torch.int32, device=dev + ) + + return batch + + +def _gather_tensor_from_tp_group(tensor): + """Gather tensors from all TP ranks for comparison.""" + assert tensor is not None, "Tensor should not be None" + tp_size = parallel_state.get_tensor_model_parallel_world_size() + gathered = [torch.zeros_like(tensor) for _ in range(tp_size)] + torch.distributed.all_gather( + gathered, tensor, group=parallel_state.get_tensor_model_parallel_group() + ) + return gathered + + +def _gather_tensor_from_all_ranks(tensor): + """Gather tensors from all PP ranks for comparison.""" + assert tensor is not None, "Tensor should not be None" + if type(tensor) is int: + tensor = torch.tensor(tensor, dtype=torch.int32, device=torch.cuda.current_device()) + gathered = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(gathered, tensor) + return gathered + + +@pytest.mark.parametrize( + ("tp", "pp", "cp", "dynamic_cp"), + [ + (1, 1, 1, False), # Basic case: no parallelism + (2, 1, 1, False), # Tensor parallel only + (1, 2, 1, False), # Pipeline parallel only + (2, 2, 1, False), # TP + PP + (1, 1, 2, False), # CP only + (2, 1, 2, False), # TP + CP + (1, 2, 2, False), # PP + CP + (1, 4, 1, False), # Has middle pp stage + (1, 1, 1, True), # Dynamic CP enabled (CP=1 with hybrid groups) + (2, 1, 1, True), # TP + Dynamic CP + ], +) +def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, dynamic_cp): + """ + Test get_batch_on_this_rank_for_sequence_packing function with variable-length THD format. + + This test verifies: + 1. TP ranks: All ranks within a TP group receive identical data after broadcast + 2. PP ranks: Middle PP ranks have the same packed_seq_params as first/last stages + 3. 
CP ranks: Data is correctly partitioned with proper shape and values + 4. Variable-length (THD) format: Different sequence lengths are handled correctly + """ + args = SimpleNamespace() + args.tensor_model_parallel_size = tp + args.pipeline_model_parallel_size = pp + args.context_parallel_size = cp + args.dynamic_context_parallel = dynamic_cp + args.virtual_pipeline_model_parallel_size = None + args.data_parallel_size = 8 // (tp * pp * cp) + args.seq_length = 8192 + + # Skip invalid configurations + if args.data_parallel_size < 1: + raise ValueError(f"Invalid config: tp={tp}, pp={pp}, cp={cp} exceeds world size 8") + + # Initialize model parallel + Utils.initialize_model_parallel( + tp, + pp, + None, + context_parallel_size=cp, + dynamic_context_parallel=dynamic_cp, + min_dynamic_context_parallel_size=1, + ) + + try: + # Create mock data iterator with variable-length sequences + # Only TP rank 0 needs the iterator; other TP ranks pass None + tp_rank = parallel_state.get_tensor_model_parallel_rank() + local_cp_size = 8 // (tp * pp) if dynamic_cp else None + + if tp_rank == 0: + # Use deterministic seed based on DP rank so same data within TP/PP/CP group + dp_rank = parallel_state.get_data_parallel_rank() + sequence_lengths = [1024, 2048, 512, 1536, 3072] + assert ( + sum(sequence_lengths) == args.seq_length + ), f"Sequence lengths sum {sum(sequence_lengths)} != total {args.seq_length}" + data_iterator = iter( + MockVariableLengthSequencePackingDataIterator( + total_seq_length=args.seq_length, + sequence_lengths=sequence_lengths, # Variable lengths, sum=8192 + local_cp_size=local_cp_size, + seed=42 + dp_rank, # Same seed within PP/CP group + ) + ) + else: + # Non-TP-rank-0 ranks don't need the iterator + data_iterator = None + + # Call the function under test + result = get_batch_on_this_rank_for_sequence_packing( + data_iterator=data_iterator, + mtp_on_this_rank=False, + vp_stage=None, + dynamic_context_parallel=dynamic_cp, + ) + + # Unpack the result + tokens, 
labels, loss_mask, attention_mask, position_ids, packed_seq_params = result + + # Get parallel state info + tp_rank = parallel_state.get_tensor_model_parallel_rank() + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + cp_rank = parallel_state.get_context_parallel_rank() + is_first_stage = parallel_state.is_pipeline_first_stage(ignore_virtual=True) + is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=True) + is_first_or_last = is_first_stage or is_last_stage + + # ===================================================================== + # TEST 1: Verify data based on pipeline stage + # ===================================================================== + if is_first_stage: + assert tokens is not None, "First stage should have tokens" + assert position_ids is not None, "First stage should have position_ids" + assert tokens.dim() == 2, "Tokens should be 2D (batch, seq)" + assert position_ids.dim() == 2, "Position IDs should be 2D (batch, seq)" + assert tokens.size(0) == 1, "batch should be 1 in THD format" + assert position_ids.size(0) == 1, "batch should be 1 in THD format" + else: + assert tokens is None, "Non-first stage should not have tokens" + assert position_ids is None, "Non-first stage should not have position_ids" + + if is_last_stage: + assert labels is not None, "Last stage should have labels" + assert loss_mask is not None, "Last stage should have loss_mask" + assert labels.dim() == 2, "Labels should be 2D (batch, seq)" + assert loss_mask.dim() == 2, "Loss mask should be 2D (batch, seq)" + assert labels.size(0) == 1, "batch should be 1 in THD format" + assert loss_mask.size(0) == 1, "batch should be 1 in THD format" + else: + assert labels is None, "Non-last stage should not have labels" + assert loss_mask is None, "Non-last stage should not have loss_mask" + + # ===================================================================== + # TEST 2: Verify all ranks have consistent packed_seq_params + # 
===================================================================== + assert packed_seq_params is not None + assert packed_seq_params.qkv_format == "thd" + if dynamic_cp: + assert packed_seq_params.local_cp_size is not None + assert packed_seq_params.cp_group is not None + + test_keys = [ + "cu_seqlens_q", + "cu_seqlens_q_padded", + "max_seqlen_q", + "cu_seqlens_kv", + "cu_seqlens_kv_padded", + "max_seqlen_kv", + ] + if dynamic_cp: + test_keys.append("local_cp_size") + for key in test_keys: + tensor = getattr(packed_seq_params, key) + assert tensor is not None + gathered_tensor = _gather_tensor_from_all_ranks(tensor) + for i in range(1, len(gathered_tensor)): + assert torch.equal( + gathered_tensor[0], gathered_tensor[i] + ), f"Rank 0 and rank {i} have different {key}" + + # ===================================================================== + # TEST 3: Verify TP ranks receive identical data after broadcast + # ===================================================================== + if tp > 1: + test_tensors = [] + if is_first_stage: + test_tensors.extend([tokens, position_ids]) + if is_last_stage: + test_tensors.extend([labels, loss_mask]) + + for tensor in test_tensors: + gathered_tensors = _gather_tensor_from_tp_group(tensor) + for i in range(1, tp): + assert torch.equal( + gathered_tensors[0], gathered_tensors[i] + ), f"TP rank 0 and rank {i} have different data" + + # ===================================================================== + # TEST 4: Verify CP partitioning + # ===================================================================== + if cp > 1 or dynamic_cp: + if dynamic_cp: + assert packed_seq_params.local_cp_size is not None + cp_size = packed_seq_params.local_cp_size + assert packed_seq_params.cp_group == ( + parallel_state.get_dynamic_data_context_parallel_groups(group_size=cp_size) + ) + else: + cp_size = cp + + # With CP, the sequence should be partitioned + expected_seq_len = args.seq_length // cp_size + + if is_first_stage: + 
actual_seq_len = tokens.shape[1] + assert ( + actual_seq_len == expected_seq_len + ), f"CP partitioned tokens have wrong shape: {actual_seq_len} != {expected_seq_len}" + + # Verify labels only if all CP ranks are at last stage + if is_last_stage: + actual_seq_len = labels.shape[1] + assert ( + actual_seq_len == expected_seq_len + ), f"CP partitioned labels have wrong shape: {actual_seq_len} != {expected_seq_len}" + + finally: + Utils.destroy_model_parallel() + unset_global_variables() + + +@pytest.mark.parametrize( + ("tp", "pp", "cp", "vpp","scheduler_type"), + [ + (1, 1, 1, None, PackingScheduler.DEFAULT_DYNAMIC_CP), + (1, 1, 8, None, PackingScheduler.NAIVE_SEQUENCE_PACKING), + (2, 1, 1, None, PackingScheduler.DEFAULT_DYNAMIC_CP), + (2, 1, 4, None, PackingScheduler.NAIVE_SEQUENCE_PACKING), + (2, 4, 1, None, PackingScheduler.NAIVE_SEQUENCE_PACKING), + (2, 2, 1, None, PackingScheduler.DEFAULT_DYNAMIC_CP), + (2, 2, 1, None, PackingScheduler.NAIVE_SEQUENCE_PACKING), + (1, 4, 1, 4, PackingScheduler.DEFAULT_DYNAMIC_CP), + (1, 4, 1, 4, PackingScheduler.NAIVE_SEQUENCE_PACKING), + ], +) +def test_wrap_dataloader(tp, pp, cp, vpp, scheduler_type): + ''' + Test wrap_dataloader function with different scheduler types. 
+ ''' + args = SimpleNamespace() + args.tensor_model_parallel_size = tp + args.pipeline_model_parallel_size = pp + args.context_parallel_size = cp + args.virtual_pipeline_model_parallel_size = None + args.data_parallel_size = 8 // (tp * pp * cp) + args.seq_length = 8192 + args.max_seqlen_per_dp_cp_rank = 8192 + if scheduler_type is PackingScheduler.DEFAULT_DYNAMIC_CP: + args.dynamic_context_parallel = True + elif scheduler_type is PackingScheduler.NAIVE_SEQUENCE_PACKING: + args.dynamic_context_parallel = False + + # Skip invalid configurations + if args.data_parallel_size < 1: + raise ValueError(f"Invalid config: tp={tp}, pp={pp}, cp={cp} exceeds world size 8") + + def _create_single_sample(seq_len): + # hard code the padding size to 16 + pad_size = 16 + seq_len_padded = ((seq_len + pad_size - 1) // pad_size) * pad_size + device = torch.device("cuda", torch.cuda.current_device()) + tokens = torch.randint(0, 16384, (seq_len_padded,), dtype=torch.int64, device=device) + labels = tokens + 1 + position_ids = torch.arange(seq_len_padded, dtype=torch.int64, device=device) + loss_mask = torch.ones(seq_len_padded, dtype=torch.float32, device=device) + loss_mask[0:seq_len] = 1 + loss_mask[seq_len:] = 0 + original_seq_len = torch.tensor(seq_len, dtype=torch.int32, device=tokens.device) + padded_seq_len = torch.tensor(seq_len_padded, dtype=torch.int32, device=tokens.device) + return { + 'tokens': tokens, + 'labels': labels, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + "original_seq_len": original_seq_len, + "padded_seq_len": padded_seq_len, + } + + # Initialize model parallel + Utils.initialize_model_parallel( + tp, + pp, + vpp, + context_parallel_size=cp, + dynamic_context_parallel=args.dynamic_context_parallel, + min_dynamic_context_parallel_size=1, + ) + + global_batch_size = 64 + micro_batch_size = 1 + import random + nums = [random.randint(2048, args.seq_length) for _ in range(global_batch_size)] # 64 sequences + + config = SimpleNamespace() + 
config.max_seqlen_per_dp_cp_rank = args.max_seqlen_per_dp_cp_rank + config.microbatch_group_size_per_vp_stage = pp + config.dynamic_context_parallel = args.dynamic_context_parallel + config.virtual_pipeline_model_parallel_size = vpp + + + dp_rank = parallel_state.get_data_parallel_rank() + dp_size = parallel_state.get_data_parallel_world_size() + + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + tp_rank = parallel_state.get_tensor_model_parallel_rank() + + is_pp_first = (pp_rank == 0) + is_pp_last = (pp_rank == pp - 1) + is_tp_first = (tp_rank == 0) + + if is_tp_first and (is_pp_first or is_pp_last): + num_per_dp = global_batch_size // dp_size // micro_batch_size + samples = [_create_single_sample(num) for num in nums[dp_rank*num_per_dp:(dp_rank+1)*num_per_dp]] + data_iterator = RerunDataIterator(iter([samples])) + else: + data_iterator = None + + if is_tp_first: + if vpp is not None and vpp > 1: + if is_pp_first: + data_iterator = [data_iterator] + [None for _ in range(vpp - 1)] + elif is_pp_last: + data_iterator = [None for _ in range(vpp - 1)] + [data_iterator] + else: + data_iterator = [None for _ in range(vpp)] + try: + # Call the function under test + (new_data_iterator, num_micro_batches, num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch) = wrap_dataloader(data_iterator, config, scheduler_type) + + # check the result + assert type(num_micro_batches) is int + assert type(num_total_tokens_this_global_batch) is float or type(num_total_tokens_this_global_batch) is numpy.float32 + assert type(sequence_square_sum_this_global_batch) is float or type(sequence_square_sum_this_global_batch) is numpy.float32 + + def _check_batch(batch_all, batch_keys): + for batch in batch_all: + assert set(batch.keys()) == set(batch_keys), f"batch keys: {set(batch.keys())} != {set(batch_keys)}" + for key in batch_keys: + assert batch[key] is not None + + # verify the result + if is_tp_first: + batch_keys = 
["cu_seqlens","max_seqlen","cu_seqlens_padded"] + if scheduler_type is PackingScheduler.DEFAULT_DYNAMIC_CP: + batch_keys.append("local_cp_size") + if vpp is not None and vpp > 1: + if is_pp_first: + for data_iterator in new_data_iterator[1:]: + batch_all = [next(data_iterator) for _ in range(num_micro_batches)] + _check_batch(batch_all, batch_keys) + new_data_iterator = new_data_iterator[0] + elif is_pp_last: + for data_iterator in new_data_iterator[:-1]: + batch_all = [next(data_iterator) for _ in range(num_micro_batches)] + _check_batch(batch_all, batch_keys) + new_data_iterator = new_data_iterator[-1] + else: + for data_iterator in new_data_iterator: + batch_all = [next(data_iterator) for _ in range(num_micro_batches)] + _check_batch(batch_all, batch_keys) + new_data_iterator = new_data_iterator[0] + + batch_all = [next(new_data_iterator) for _ in range(num_micro_batches)] + if is_pp_first or is_pp_last: + batch_keys += ["tokens", "position_ids", "labels", "loss_mask"] + + _check_batch(batch_all, batch_keys) + else: + if vpp is not None and vpp > 1: + assert type(new_data_iterator) is list and len(new_data_iterator) == vpp + for data_iterator in new_data_iterator: + assert data_iterator is None + else: + assert new_data_iterator is None + + finally: + if torch.distributed.get_rank() == 0: + print(f"rank:0, exit test_wrap_dataloader successfully with tp:{tp}, pp:{pp}, cp:{cp}, vpp:{vpp}, scheduler_type:{scheduler_type}",flush=True) + Utils.destroy_model_parallel() + unset_global_variables() \ No newline at end of file diff --git a/tests/unit_tests/test_parallel_state.py b/tests/unit_tests/test_parallel_state.py index 0c722ee0257..961eb63dcdb 100644 --- a/tests/unit_tests/test_parallel_state.py +++ b/tests/unit_tests/test_parallel_state.py @@ -507,9 +507,9 @@ def golden_rank_result_from_past_code( "world_size, tp_size, cp_size, dp_size", [(8, 1, 2, 4), (8, 1, 1, 8)], # 8 GPUs, 1 TP, 2 CP, 4 DP # 8 GPUs, 1 TP, 1 CP, 8 DP ) -def test_hybrid_dp_cp_groups(world_size, 
tp_size, cp_size, dp_size): +def test_dynamic_dp_cp_groups(world_size, tp_size, cp_size, dp_size): """ - Test that hybrid DPxCP groups are created correctly. + Test that dynamic DPxCP groups are created correctly. """ Utils.destroy_model_parallel() @@ -520,13 +520,13 @@ def test_hybrid_dp_cp_groups(world_size, tp_size, cp_size, dp_size): Utils.initialize_model_parallel( tensor_model_parallel_size=tp_size, context_parallel_size=cp_size, - hybrid_context_parallel=True, + dynamic_context_parallel=True, ) dp_cp_size = ps.get_data_parallel_world_size(with_context_parallel=True) group_sizes = [2**i for i in range(int(log2(dp_cp_size)))][1:] for group_size in group_sizes: - group = ps.get_hybrid_data_context_parallel_groups(group_size=group_size) + group = ps.get_dynamic_data_context_parallel_groups(group_size=group_size) assert group.size() == group_size Utils.destroy_model_parallel()