diff --git a/megatron/core/full_cuda_graph.py b/megatron/core/full_cuda_graph.py index 7c11195f33b..f5e077a2206 100644 --- a/megatron/core/full_cuda_graph.py +++ b/megatron/core/full_cuda_graph.py @@ -4,6 +4,7 @@ import logging +import gc import torch from megatron.core.tensor_parallel.random import get_all_rng_states @@ -180,12 +181,10 @@ def __call__(self, *args, **kwargs): torch.cuda.synchronize() torch.distributed.barrier() logger.info(f'CUDA graph capture done for {training_str}!!!') - if FullCudaGraphWrapper.cuda_graph[training_str] is None: FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs) else: FullCudaGraphWrapper.cuda_graph[training_str].replay() - self.next_iter(training_str) return FullCudaGraphWrapper.result[training_str] @@ -196,3 +195,19 @@ def curr_iter(self, stage): def next_iter(self, stage): """Increment current training/validation iteration.""" FullCudaGraphWrapper.curr_iteration[stage] += 1 + + def reset_cuda_graph(self, stage=None): + """Reset CUDA graph.""" + if stage is None or stage == 'training': + if FullCudaGraphWrapper.cuda_graph['training'] is not None: + del FullCudaGraphWrapper.cuda_graph['training'] + FullCudaGraphWrapper.cuda_graph['training'] = None + FullCudaGraphWrapper.result['training'] = None + FullCudaGraphWrapper.curr_iteration['training'] = 0 + if stage is None or stage == 'validation': + if FullCudaGraphWrapper.cuda_graph['validation'] is not None: + del FullCudaGraphWrapper.cuda_graph['validation'] + FullCudaGraphWrapper.cuda_graph['validation'] = None + FullCudaGraphWrapper.result['validation'] = None + FullCudaGraphWrapper.curr_iteration['validation'] = 0 + gc.collect() \ No newline at end of file diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index d6496db09fd..f7db77a0bad 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -26,6 +26,7 @@ from megatron.core.tensor_parallel import gather_from_sequence_parallel_region from megatron.core.transformer.enums import CudaGraphScope, ModelType from megatron.core.transformer.linear_cross_entropy import LinearCrossEntropyModule +from megatron.core.transformer.moe.paged_stash import paged_stash_init_chunk_handler from megatron.core.transformer.multi_token_prediction import ( MultiTokenPredictionBlock, mtp_on_this_rank, @@ -483,6 +484,12 @@ def preprocess_for_fine_grained_offloading(self): off_interface.mark_not_offload(param) self.disable_param_offloading = False + def preprocess_for_paged_stash(self): + """Preprocess for paged stash.""" + return paged_stash_init_chunk_handler( + vp_size=self.config.virtual_pipeline_model_parallel_size, vp_stage=self.vp_stage + ) + def forward( self, input_ids: Tensor, @@ -519,6 +526,9 @@ def forward( if self.config.fine_grained_activation_offloading: self.preprocess_for_fine_grained_offloading() + if self.config.moe_paged_stash: + self.preprocess_for_paged_stash() + inference_context = deprecate_inference_params(inference_context, inference_params) preproc_output = self._preprocess( @@ -823,6 +833,8 @@ def build_schedule_plan( if self.config.fine_grained_activation_offloading: self.preprocess_for_fine_grained_offloading() + if self.config.moe_paged_stash: + self.preprocess_for_paged_stash() from ..common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 759dc9173ad..4f00736c41f 100644 --- 
a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -13,6 +13,7 @@ FineGrainedActivationOffloadingInterface as off_interface, ) from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator +from megatron.core.transformer.moe.paged_stash import paged_stash_reset from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator from megatron.core.pipeline_parallel.utils import ( is_pp_first_stage, @@ -637,6 +638,9 @@ def forward_backward_no_pipelining( if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) + if config.moe_paged_stash: + paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config) + no_sync_func = config.no_sync_func if no_sync_func is None: no_sync_func = contextlib.nullcontext @@ -1063,6 +1067,9 @@ def forward_backward_pipelining_with_interleaving( adjust_tensor_shapes_fn is None ), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism" + if config.moe_paged_stash: + paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config) + if config.overlap_p2p_comm and config.batch_p2p_comm: raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") @@ -2265,6 +2272,9 @@ def forward_backward_pipelining_without_interleaving( if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) + if config.moe_paged_stash: + paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config) + # Disable async grad reductions no_sync_func = config.no_sync_func if no_sync_func is None: diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 976c9df3cd6..16836afbf2b 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -5,6 +5,7 @@ from collections.abc import Callable from copy import deepcopy from dataclasses import dataclass +from contextlib import nullcontext from functools import partial from itertools import chain from math import ceil @@ -43,6 +44,11 @@ get_align_size_for_quantization, skip_routed_expert_padding, ) +from megatron.core.transformer.moe.paged_stash import ( + get_paged_stash_context, + paged_stash_group_commit, + paged_stash_group_start, +) from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.utils import ( ensure_metadata_has_dp_cp_group, @@ -53,6 +59,7 @@ if HAVE_TE: from megatron.core.extensions.transformer_engine import Fp8Padding, Fp8Unpadding + import transformer_engine as te else: Fp8Padding, Fp8Unpadding = None, None @@ -942,19 +949,38 @@ def _fused_forward( tokens_per_expert = torch.tensor( tokens_per_expert, dtype=torch.int, device=permuted_probs.device ) - - # Call fused impl - output = ops( - permuted_local_hidden_states, - tokens_per_expert, # FC1 - permuted_probs, # Scaled SwiGLU - tokens_per_expert, # FC2 - ) - + # if the number of tokens is 0, pad the hidden states to 256 + + if self.config.moe_paged_stash: + permuted_local_hidden_states = paged_stash_group_start(permuted_local_hidden_states) + max_num_tokens = permuted_local_hidden_states.shape[0] + # Average/expected tokens is a pre-padding estimate used by paged stashing heuristics. + # moe_expert_rank_capacity_factor is required when moe_paged_stash is enabled. 
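+            # Illustrative example: with moe_expert_rank_capacity_factor=2.0 and a padded budget of 16384 tokens, avg_num_tokens = int(16384 // 2.0) = 8192.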
+ cap_factor = self.config.moe_expert_rank_capacity_factor + avg_num_tokens = ( + int(max_num_tokens // cap_factor) if cap_factor is not None and cap_factor > 0 else None + ) + stash_context = get_paged_stash_context( + name="grouped_mlp", + max_num_tokens=max_num_tokens, + num_tokens_tensor=tokens_per_expert.sum(), + avg_num_tokens=avg_num_tokens, + ) + else: + stash_context = nullcontext() + with stash_context: + # Call fused impl + output = ops( + permuted_local_hidden_states, + tokens_per_expert, # FC1 + permuted_probs, # Scaled SwiGLU + tokens_per_expert, # FC2 + ) # Remove padding if needed if unpadded_tokens_per_expert is not None: output = self.quantization_unpadding(output, unpadded_tokens_per_expert) - + if self.config.moe_paged_stash: + output = paged_stash_group_commit(output, name="grouped_mlp") return output def bias_act_func(self, intermediate_parallel, bias_parallel, permuted_probs): diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py new file mode 100644 index 00000000000..493f34dc5e1 --- /dev/null +++ b/megatron/core/transformer/moe/paged_stash.py @@ -0,0 +1,1328 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import logging +from contextlib import nullcontext +from typing import Any + +import torch +import triton +import triton.language as tl + +from megatron.core._rank_utils import log_single_rank +from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer +from megatron.core.full_cuda_graph import FullCudaGraphWrapper +from megatron.core.utils import get_attr_wrapped_model + +logger = logging.getLogger(__name__) + +GLOBAL_BLOCK_SIZE = 1024 +SCALE_INV_BLOCK_SIZE = 32 + + +class PagedStashBuffer: + """ + A paged stash buffer with page-level memory management. + Supports both CUDA and optional pinned host buffer for overflow fallback. + + Buffers are organized as [num_pages, page_size, hidden_size]. + Uses per-buffer free lists (circular buffer) tracked as two-element state: [0]=CUDA, [1]=host. + """ + + def __init__( + self, num_tokens, hidden_size, page_size, device, overflow, host_spill, dtype, num_tokens_host=0 + ): + """ + Args: + num_tokens: Maximum number of tokens the CUDA buffer can hold + hidden_size: Hidden dimension size + page_size: Number of tokens per page + device: Device for the buffer + overflow: Overflow flag tensor (shared across all buffers) + host_spill: Global flag set to 1 if any stash used pinned host (shared) + dtype: Data type + num_tokens_host: If > 0, allocate pinned host buffer with this many tokens for spillover. 
+ """ + self.hidden_size = hidden_size + self.page_size = page_size + self.device = device + self.dtype = dtype + self.overflow = overflow # GPU flag (shared) + self.host_spill = host_spill + + # CUDA buffer + self.num_cuda_pages = (num_tokens + page_size - 1) // page_size + self.total_cuda_tokens = self.num_cuda_pages * page_size + self.cuda_buffer = torch.empty( + (self.total_cuda_tokens, hidden_size), dtype=dtype, device=device + ) + + # Host buffer (pinned), optional + self.num_host_pages = (num_tokens_host + page_size - 1) // page_size if num_tokens_host > 0 else 0 + self.total_host_tokens = self.num_host_pages * page_size if self.num_host_pages > 0 else 0 + if self.num_host_pages > 0: + self.host_buffer = torch.empty( + (self.total_host_tokens, hidden_size), dtype=dtype, device='cpu', pin_memory=True + ) + else: + self.host_buffer = None + + # Free list state: shape (2,) index 0 = CUDA, 1 = host (all in device memory for kernel) + self.free_list_head = torch.zeros(2, dtype=torch.int64, device=device) + self.free_list_tail = torch.tensor( + [self.num_cuda_pages, self.num_host_pages], dtype=torch.int64, device=device + ) + self.free_list_capacity = torch.tensor( + [self.num_cuda_pages, self.num_host_pages], dtype=torch.int64, device=device + ) + + # Free list arrays (device memory): page IDs for each buffer + self.free_list_cuda = torch.arange(self.num_cuda_pages, dtype=torch.int64, device=device) + if self.num_host_pages > 0: + self.free_list_host = torch.arange(self.num_host_pages, dtype=torch.int64, device=device) + else: + self.free_list_host = torch.empty(0, dtype=torch.int64, device=device) + + # Pre-allocated reset values (CUDA graph safe: no allocation in reset()) + self._reset_tail = torch.tensor( + [self.num_cuda_pages, self.num_host_pages], + dtype=torch.int64, + device=device, + ) + self._reset_free_list_cuda = torch.arange( + self.num_cuda_pages, dtype=torch.int64, device=device + ) + if self.num_host_pages > 0: + self._reset_free_list_host = torch.arange( + self.num_host_pages, dtype=torch.int64, device=device + ) + else: + self._reset_free_list_host = None + + def reset(self): + """Reset both CUDA and host free lists (CUDA graph safe: no new allocations).""" + self.free_list_cuda.copy_(self._reset_free_list_cuda) + self.free_list_head.zero_() + self.free_list_tail.copy_(self._reset_tail) + if self._reset_free_list_host is not None: + self.free_list_host.copy_(self._reset_free_list_host) + + def __repr__(self): + return ( + f"PagedStashBuffer(num_cuda_pages={self.num_cuda_pages}, num_host_pages={self.num_host_pages}, " + f"page_size={self.page_size}, hidden_size={self.hidden_size}, device={self.device}, dtype={self.dtype})" + ) + + +@triton.jit +def _paged_stash_copy_kernel( + src_ptr, + cuda_dst_ptr, + host_dst_ptr, + num_tokens_ptr, + free_list_cuda_ptr, + free_list_host_ptr, + free_list_head_ptr, # shape (2,): [cuda_head, host_head] + free_list_tail_ptr, # shape (2,) + free_list_capacity_ptr, + page_record_ptr, + overflow_ptr, + host_spill_global_ptr, # 1 if any successful host spill (not set on overflow path) + spilled_to_host_ptr, # Output: 0 = stored in CUDA, 1 = stored in host or overflow + new_free_list_head_ptr, # Output: shape (2,) updated heads + PAGE_SIZE: tl.constexpr, + HIDDEN_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HAS_HOST_BUFFER: tl.constexpr, +): + """Copy tokens to paged stash: try CUDA first (fast path), then host if CUDA full.""" + pid = tl.program_id(axis=0) + num_blocks = tl.num_programs(axis=0) + + # Load overflow first (get in flight 
early); branch on it only before any write + overflow = tl.load(overflow_ptr) + + num_tokens = tl.load(num_tokens_ptr) + required_pages = tl.cdiv(num_tokens, PAGE_SIZE) + + # Common case: load only CUDA state (and head_host for output when use_cuda) + head_cuda = tl.load(free_list_head_ptr) + head_host = tl.load(free_list_head_ptr + 1) + tail_cuda = tl.load(free_list_tail_ptr) + cap_cuda = tl.load(free_list_capacity_ptr) + + avail_cuda = tail_cuda - head_cuda + use_cuda = avail_cuda >= required_pages + + # Assume CUDA path: set everything for GPU stash + spill = 0 + dst_ptr = cuda_dst_ptr + free_list_ptr = free_list_cuda_ptr + head = head_cuda + cap = cap_cuda + new_head_cuda = head_cuda + required_pages + new_head_host = head_host + + if overflow == 1: + # No stash; preserve heads so Python copy_ does not write garbage into the buffer. + if pid == 0: + tl.store(new_free_list_head_ptr, head_cuda) + tl.store(new_free_list_head_ptr + 1, head_host) + return + + # Only when CUDA is full: load host state and maybe switch to host + if not use_cuda: + tail_host = tl.load(free_list_tail_ptr + 1) + cap_host = tl.load(free_list_capacity_ptr + 1) + use_host = HAS_HOST_BUFFER == 1 and (tail_host - head_host) >= required_pages + if use_host: + spill = 1 + dst_ptr = host_dst_ptr + free_list_ptr = free_list_host_ptr + head = head_host + cap = cap_host + new_head_cuda = head_cuda + new_head_host = head_host + required_pages + else: + if pid == 0: + tl.store(overflow_ptr, 1) + tl.store(spilled_to_host_ptr, 1) + tl.store(new_free_list_head_ptr, head_cuda) + tl.store(new_free_list_head_ptr + 1, head_host) + return + + if pid == 0: + tl.store(spilled_to_host_ptr, spill) + if spill == 1: + tl.store(host_spill_global_ptr, 1) + + # Copy loop: strided over tokens + token_idx = pid + while token_idx < num_tokens: + page_slot = token_idx // PAGE_SIZE + token_in_page = token_idx % PAGE_SIZE + free_list_idx = (head + page_slot) % cap + page_id = tl.load(free_list_ptr + free_list_idx) + if token_in_page == 0: + tl.store(page_record_ptr + page_slot, page_id) + dst_token_idx = page_id * PAGE_SIZE + token_in_page + + elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE + need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 + num_iters = elements_per_thread + (1 if need_mask else 0) + token_idx_i64 = token_idx.to(tl.int64) + dst_token_idx_i64 = dst_token_idx.to(tl.int64) + src_base = src_ptr + token_idx_i64 * HIDDEN_SIZE + dst_base = dst_ptr + dst_token_idx_i64 * HIDDEN_SIZE + + if need_mask: + for iter in range(num_iters): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + hidden_mask = hidden_offsets < HIDDEN_SIZE + data = tl.load(src_base + hidden_offsets, mask=hidden_mask, other=0) + tl.store(dst_base + hidden_offsets, data, mask=hidden_mask) + else: + for iter in range(elements_per_thread): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + data = tl.load(src_base + hidden_offsets) + tl.store(dst_base + hidden_offsets, data) + token_idx += num_blocks + + if pid == 0: + tl.store(new_free_list_head_ptr, new_head_cuda) + tl.store(new_free_list_head_ptr + 1, new_head_host) + + +@triton.jit +def _paged_stash_pop_kernel( + cuda_src_ptr, + host_src_ptr, + dst_ptr, + num_tokens_ptr, + page_record_ptr, + spilled_to_host_ptr, # 0 = read from CUDA, 1 = read from host + overflow_ptr, + free_list_cuda_ptr, + free_list_host_ptr, + free_list_tail_ptr, # shape (2,) + free_list_capacity_ptr, + new_free_list_tail_ptr, # Output: shape (2,) updated tails + PAGE_SIZE: tl.constexpr, + HIDDEN_SIZE: tl.constexpr, + 
BLOCK_SIZE: tl.constexpr, +): + """Reload tokens from paged stash; CUDA path fast, host path when spilled_to_host.""" + pid = tl.program_id(axis=0) + num_blocks = tl.num_programs(axis=0) + + # Load overflow first (get in flight early); branch on it only before any write + overflow = tl.load(overflow_ptr) + + num_tokens = tl.load(num_tokens_ptr) + spill = tl.load(spilled_to_host_ptr) + required_pages = tl.cdiv(num_tokens, PAGE_SIZE) + + # Common case: load only CUDA state (and tail_host for output when spill=0) + tail_cuda = tl.load(free_list_tail_ptr) + tail_host = tl.load(free_list_tail_ptr + 1) + cap_cuda = tl.load(free_list_capacity_ptr) + + if overflow == 1: + # No pop; preserve tails so Python copy_ does not write garbage into the buffer. + if pid == 0: + tl.store(new_free_list_tail_ptr, tail_cuda) + tl.store(new_free_list_tail_ptr + 1, tail_host) + return + + # Assume CUDA path + src_ptr = cuda_src_ptr + free_list_ptr = free_list_cuda_ptr + tail = tail_cuda + cap = cap_cuda + new_tail_cuda = tail_cuda + required_pages + new_tail_host = tail_host + + # Only when spilled to host: load host state and switch + if spill == 1: + cap_host = tl.load(free_list_capacity_ptr + 1) + if cap_host == 0: + # Cannot pop from host; preserve tails (no-op for free-list state). + if pid == 0: + tl.store(new_free_list_tail_ptr, tail_cuda) + tl.store(new_free_list_tail_ptr + 1, tail_host) + return + src_ptr = host_src_ptr + free_list_ptr = free_list_host_ptr + tail = tail_host + cap = cap_host + new_tail_cuda = tail_cuda + new_tail_host = tail_host + required_pages + + token_idx = pid + while token_idx < num_tokens: + page_slot = token_idx // PAGE_SIZE + token_in_page = token_idx % PAGE_SIZE + page_id = tl.load(page_record_ptr + page_slot) + src_token_idx = page_id * PAGE_SIZE + token_in_page + + elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE + need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 + num_iters = elements_per_thread + (1 if need_mask else 0) + src_token_idx_i64 = src_token_idx.to(tl.int64) + token_idx_i64 = token_idx.to(tl.int64) + src_base = src_ptr + src_token_idx_i64 * HIDDEN_SIZE + dst_base = dst_ptr + token_idx_i64 * HIDDEN_SIZE + + if need_mask: + for iter in range(num_iters): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + hidden_mask = hidden_offsets < HIDDEN_SIZE + data = tl.load(src_base + hidden_offsets, mask=hidden_mask, other=0) + tl.store(dst_base + hidden_offsets, data, mask=hidden_mask) + else: + for iter in range(elements_per_thread): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + data = tl.load(src_base + hidden_offsets) + tl.store(dst_base + hidden_offsets, data) + + if token_in_page == 0: + write_idx = (tail + page_slot) % cap + tl.store(free_list_ptr + write_idx, page_id) + token_idx += num_blocks + + if pid == 0: + tl.store(new_free_list_tail_ptr, new_tail_cuda) + tl.store(new_free_list_tail_ptr + 1, new_tail_host) + + +class PagedTensor: + """ + A paged tensor that stores data in pages within a paged stash buffer. 
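+    ``page_record`` tracks which stash pages hold this tensor's tokens, and ``spilled_to_host`` records whether the copy kernel placed them in the CUDA buffer or the pinned-host buffer.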
+ """ + + def __init__( + self, + tensor, + num_tokens_tensor=None, + avg_num_tokens: int = None, + vp_stage=None, + original_shape=None, + schedule_layer_no=None, + is_columnwise_scale_inv=None, + max_num_tokens=None, + hidden_size=None, + page_size=64, + ): + """ + Args: + tensor: The tensor to store + num_tokens_tensor: Scalar tensor containing actual number of tokens + vp_stage: Virtual pipeline stage + layer_name: Name of the layer + max_num_tokens: Maximum number of tokens + hidden_size: Hidden size + page_size: Number of tokens per page + """ + self._tensor = tensor + self._original_tensor = None + assert ( + num_tokens_tensor is not None + and isinstance(num_tokens_tensor, torch.Tensor) + and num_tokens_tensor.numel() == 1 + ) + self.num_tokens_tensor = num_tokens_tensor.clone() + self.avg_num_tokens = avg_num_tokens + self.vp_stage = vp_stage + self.schedule_layer_no = schedule_layer_no + self.is_columnwise_scale_inv = is_columnwise_scale_inv + self.max_num_tokens = max_num_tokens + self.hidden_size = hidden_size + self.page_size = page_size + + # Original tensor information + self.original_shape = list(tensor.shape) if original_shape is None else original_shape + self.element_size = tensor.element_size() + self.dtype = tensor.dtype + self.device = tensor.device + + # Calculate number of pages needed + self.max_num_pages = (self.max_num_tokens + page_size - 1) // page_size # Ceiling division + + # Page record: stores which pages are being used for this tensor + self.page_record = torch.zeros(self.max_num_pages, dtype=torch.int64, device=self.device) + # Set by copy kernel: 0 = data in CUDA stash, 1 = data in host (pinned) stash + self.spilled_to_host = torch.zeros(1, dtype=torch.int64, device=self.device) + + @property + def schedule_layer(self): + """Get the schedule layer.""" + return self.schedule_layer_no + + def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048): + """Offload the paged tensor to paged stash buffer (CUDA or host if CUDA full).""" + self._tensor = self._tensor.contiguous() + if self.num_tokens_tensor.dim() == 0: + self.num_tokens_tensor = self.num_tokens_tensor.reshape(1) + if self.is_columnwise_scale_inv: + num_tokens_tensor = self.num_tokens_tensor // SCALE_INV_BLOCK_SIZE + max_num_tokens = self.max_num_tokens // SCALE_INV_BLOCK_SIZE + else: + num_tokens_tensor = self.num_tokens_tensor + max_num_tokens = self.max_num_tokens + + tensor_to_copy = self._tensor + BLOCK_SIZE = GLOBAL_BLOCK_SIZE + num_blocks = min(max_num_tokens, max_blocks) + grid = (num_blocks,) + + new_free_list_head = torch.empty(2, dtype=torch.int64, device=self.device) + has_host = 1 if paged_stash_buffer.host_buffer is not None else 0 + host_dst = ( + paged_stash_buffer.host_buffer + if paged_stash_buffer.host_buffer is not None + else paged_stash_buffer.cuda_buffer + ) + + _paged_stash_copy_kernel[grid]( + tensor_to_copy.view(paged_stash_buffer.cuda_buffer.dtype), + paged_stash_buffer.cuda_buffer, + host_dst, + num_tokens_tensor, + paged_stash_buffer.free_list_cuda, + paged_stash_buffer.free_list_host, + paged_stash_buffer.free_list_head, + paged_stash_buffer.free_list_tail, + paged_stash_buffer.free_list_capacity, + self.page_record, + paged_stash_buffer.overflow, + paged_stash_buffer.host_spill, + self.spilled_to_host, + new_free_list_head, + PAGE_SIZE=self.page_size, + HIDDEN_SIZE=self.hidden_size, + BLOCK_SIZE=BLOCK_SIZE, + HAS_HOST_BUFFER=has_host, + ) + paged_stash_buffer.free_list_head.copy_(new_free_list_head) + self._original_tensor = self._tensor + 
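+        # Keep the source tensor alive until the copy launched on the pack stream completes; PagedStashManager.wait_for_stash_to_complete() drops this reference.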
self._tensor = None + + def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048): + """Reload the paged tensor from paged stash buffer (CUDA or host from spilled_to_host).""" + self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) + tensor_to_reload = self._tensor + + if self.is_columnwise_scale_inv: + num_tokens_tensor = self.num_tokens_tensor // SCALE_INV_BLOCK_SIZE + max_num_tokens = self.max_num_tokens // SCALE_INV_BLOCK_SIZE + else: + num_tokens_tensor = self.num_tokens_tensor + max_num_tokens = self.max_num_tokens + BLOCK_SIZE = GLOBAL_BLOCK_SIZE + num_blocks = min(max_num_tokens, max_blocks) + grid = (num_blocks,) + + new_free_list_tail = torch.empty(2, dtype=torch.int64, device=self.device) + host_src = ( + paged_stash_buffer.host_buffer + if paged_stash_buffer.host_buffer is not None + else paged_stash_buffer.cuda_buffer + ) + _paged_stash_pop_kernel[grid]( + paged_stash_buffer.cuda_buffer, + host_src, + tensor_to_reload.view(paged_stash_buffer.cuda_buffer.dtype), + num_tokens_tensor, + self.page_record, + self.spilled_to_host, + paged_stash_buffer.overflow, + paged_stash_buffer.free_list_cuda, + paged_stash_buffer.free_list_host, + paged_stash_buffer.free_list_tail, + paged_stash_buffer.free_list_capacity, + new_free_list_tail, + PAGE_SIZE=self.page_size, + HIDDEN_SIZE=self.hidden_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + + paged_stash_buffer.free_list_tail.copy_(new_free_list_tail) + + +class PipelinePreScheduleFunction(torch.autograd.Function): + """ + This function is used to update the pp schedule. + """ + + @staticmethod + def forward(ctx, tensor, stash_manager): # after forward + # pylint: disable=missing-function-docstring + ctx.stash_manager = stash_manager + # Wait for stash to complete before starting the next layer + stash_manager.wait_for_stash_to_complete() + return tensor + + @staticmethod + def backward(ctx, *grad_output): # before backward + # pylint: disable=missing-function-docstring + # Initiate reload for next layer + if ( + ctx.stash_manager.status == 'captured' + and ctx.stash_manager.current_schedule_index < len(ctx.stash_manager._pp_schedule) + ): + next_schedule_layer = ctx.stash_manager._pp_schedule[ + ctx.stash_manager.current_schedule_index + ] + if next_schedule_layer < 0: + ctx.stash_manager.reload_paged_tensors(-next_schedule_layer) + + return grad_output + (None, None) + + +class PipelinePostScheduleFunction(torch.autograd.Function): + """ + This function is used to update the pp schedule. 
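+    The forward hook records and validates the schedule entry and, once the schedule is captured, starts stashing the current layer and an early reload for the next backward layer; the backward hook waits for those transfers to finish before the layer's gradients are computed.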
+ """ + + @staticmethod + def forward(ctx, tensor, stash_manager): # after forward + # pylint: disable=missing-function-docstring + ctx.stash_manager = stash_manager + ctx.vp_stage = stash_manager.current_vp_stage + if ctx.vp_stage is None: + ctx.vp_stage = 0 + ctx.layer_no, ctx.microbatch_no = stash_manager.update_pp_schedule(ctx.vp_stage + 1) + + # Initiate stash for current layer and reload for next layer + if stash_manager.status == 'captured': + current_schedule_layer = stash_manager.get_schedule_layer( + ctx.vp_stage + 1, ctx.layer_no, ctx.microbatch_no + ) + next_schedule_layer = ctx.stash_manager._pp_schedule[ + ctx.stash_manager.current_schedule_index + 1 + ] + if current_schedule_layer != -next_schedule_layer: + # Start stash for current layer + ctx.stash_manager.stash_paged_tensors(current_schedule_layer) + if next_schedule_layer < 0: + # reload for next backward layer + ctx.stash_manager.reload_paged_tensors(-next_schedule_layer, no_wait=True) + else: + ctx.stash_manager.remove_paged_tensor_from_stash() + + ctx.stash_manager.current_schedule_index += 1 + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, *grad_output): # before backward + # pylint: disable=missing-function-docstring + if ctx.vp_stage is not None: + ctx.stash_manager.update_pp_schedule( + -(ctx.vp_stage + 1), -ctx.layer_no, -ctx.microbatch_no + ) + ctx.stash_manager.current_schedule_index += 1 + current_stream = torch.cuda.current_stream() + + ctx.stash_manager.wait_for_stash_to_complete() + if ctx.stash_manager._unpack_stream_status == 'reloading': + current_stream.wait_stream(ctx.stash_manager.unpack_stream) + ctx.stash_manager._unpack_stream_status = 'idle' + + return grad_output + (None, None) + + +class PagedStashManager: + """ + Singleton manager for coordinating paged stashing across pipeline stages. 
+ Manages chunk handlers, synchronizes GPU-GPU transfers, + and handles virtual pipeline parallelism + """ + + STASH_MGR = None + + @classmethod + def get_instance(cls): + """Get the singleton instance of PagedStashManager.""" + if cls.STASH_MGR is None: + cls.STASH_MGR = PagedStashManager() + return cls.STASH_MGR + + def __init__(self): + """Initialize the manager with queues and dedicated CUDA streams.""" + # allocate streams and events for synchronization + self.enabled = False + self._pack_stream = torch.cuda.Stream() + # Currently paged stashing is not stream-safe, so use the same stream for packing + # and unpacking + self._unpack_stream = self._pack_stream + self._pack_stream_status = 'idle' # idle, stashing + self._unpack_stream_status = 'idle' # idle, reloading + self.paged_tensors_to_stash = [] + self.paged_tensors_stash_in_progress = [] + self.paged_tensors_to_reload = {} + + self.iteration = 0 + self._current_layer_name = None + self.vp_size = None + self.current_vp_stage = None + self.status = 'begin' # begin, capture, captured + # If element is +ve, it denotes forward pass of vp stage, + # if -ve, it denotes backward pass of vp stage + self._pp_schedule = None + self.current_layer = None + self.current_microbatch = None + self.current_schedule_index = None + + # Track max tokens needed across all vp_stages grouped by dtype and hidden_size + self.max_tokens_across_vp_stages = None + self.temp_tokens_across_vp_stages = None + # Track max tokens computed from avg_num_tokens (heuristic) across all vp_stages + self.max_avg_tokens_across_vp_stages = None + self.temp_avg_tokens_across_vp_stages = None + + self.num_tokens_tensor = None + self.max_num_tokens = None + # Optional hint: expected/average number of tokens (e.g., pre-padding estimate) + self.avg_num_tokens = None + self.stash_buffers = None + self.overflow = None + self.host_spill = None + self.device = None + + # Page size for paged memory management (default; overwritten from config in paged_stash_reset) + self.page_size = 64 + + @property + def pack_stream(self): + """Get the pack stream.""" + return self._pack_stream + + @property + def unpack_stream(self): + """Get the unpack stream.""" + return self._unpack_stream + + def set_current_layer_name(self, name): + """Set the current layer name.""" + self._current_layer_name = name + + def get_schedule_layer(self, vp_stage, layer_no, microbatch_no): + """Get the schedule layer.""" + assert layer_no < 1000 and microbatch_no < 1000, "Schedule encoding overflow" + return vp_stage * 1000000 + layer_no * 1000 + microbatch_no + + def add_paged_tensor_to_stash(self, paged_tensor): + """Add a paged tensor to the stash list.""" + if self.status == 'captured': + self.paged_tensors_to_stash.append(paged_tensor) + else: + pass + + def remove_paged_tensor_from_stash(self): + """Remove all paged tensors from the stash list.""" + if self.status == 'captured': + self.paged_tensors_to_stash.clear() + else: + pass + + def stash_paged_tensors(self, pp_schedule_layer): + """Stash the paged tensors.""" + current_stream = torch.cuda.current_stream() + self.pack_stream.wait_stream(current_stream) + + with torch.cuda.stream(self.pack_stream): + if self.status == 'captured': + self._pack_stream_status = 'stashing' + if pp_schedule_layer not in self.paged_tensors_to_reload: + self.paged_tensors_to_reload[pp_schedule_layer] = [] + assert len(self.paged_tensors_to_reload[pp_schedule_layer]) == 0, ( + f"paged_tensors_to_reload {pp_schedule_layer} is not empty " + 
f"{self.paged_tensors_to_reload[pp_schedule_layer]}" + ) + while len(self.paged_tensors_to_stash) > 0: + paged_tensor = self.paged_tensors_to_stash.pop(0) + stash_buffer = self.stash_buffers[paged_tensor.dtype][paged_tensor.hidden_size] + paged_tensor.offload_to_stash(stash_buffer) + self.paged_tensors_to_reload[pp_schedule_layer].append(paged_tensor) + self.paged_tensors_stash_in_progress.append(paged_tensor) + else: + pass + assert ( + len(self.paged_tensors_to_stash) == 0 + ), f"paged_tensors_to_stash is not empty {self.paged_tensors_to_stash}" + + def wait_for_stash_to_complete(self): + """Wait for stash to complete.""" + current_stream = torch.cuda.current_stream() + if self._pack_stream_status == 'stashing': + current_stream.wait_stream(self.pack_stream) + self._pack_stream_status = 'idle' + + # Deallocate original tensor after stash is complete + while len(self.paged_tensors_stash_in_progress) > 0: + paged_tensor = self.paged_tensors_stash_in_progress.pop(0) + paged_tensor._original_tensor = None + + def reload_paged_tensors(self, pp_schedule_layer, no_wait=False): + """Reload the paged tensors.""" + # Avoid waiting for main stream if reload is immediately after stash + # since stash is already waiting for main stream + if not no_wait or self.unpack_stream != self.pack_stream: + current_stream = torch.cuda.current_stream() + self.unpack_stream.wait_stream(current_stream) + + with torch.cuda.stream(self.unpack_stream): + if self.status == 'captured': + self._unpack_stream_status = 'reloading' + while len(self.paged_tensors_to_reload[pp_schedule_layer]) > 0: + paged_tensor = self.paged_tensors_to_reload[pp_schedule_layer].pop(0) + stash_buffer = self.stash_buffers[paged_tensor.dtype][paged_tensor.hidden_size] + paged_tensor.reload_from_stash(stash_buffer) + else: + pass + assert len(self.paged_tensors_to_reload[pp_schedule_layer]) == 0, ( + f"paged_tensors_to_reload {pp_schedule_layer} is not empty " + f"{self.paged_tensors_to_reload[pp_schedule_layer]}" + ) + + def allocate_stash_buffers( + self, + moe_paged_stash_buffer_size_factor_cuda: float = 1.10, + moe_paged_stash_buffer_size_factor_cpu: float = 0.0, + ): + """Allocate stash buffers organized by [dtype][hidden_size].""" + self.stash_buffers = {} + if self.overflow is None: + self.overflow = torch.zeros(1, dtype=torch.int64, device=self.device) + else: + self.overflow.zero_() + if self.host_spill is None: + self.host_spill = torch.zeros(1, dtype=torch.int64, device=self.device) + else: + self.host_spill.zero_() + + cuda_factor = moe_paged_stash_buffer_size_factor_cuda + cpu_factor = moe_paged_stash_buffer_size_factor_cpu + + # Both factors use the same sign convention: + # - positive: size based on avg_num_tokens-derived maxima + # - negative: size based on actual num_tokens-derived maxima (legacy behavior) + # Scale is always abs(factor). For CPU, 0 means no host buffer. + if cuda_factor >= 0: + max_tokens_dict = self.max_avg_tokens_across_vp_stages + cuda_scale = cuda_factor + else: + max_tokens_dict = self.max_tokens_across_vp_stages + cuda_scale = -cuda_factor + + # Fallback safety: if avg-based dict is not available/populated yet, use actual-max dict. 
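+        # (With a positive factor this fallback still applies abs(factor) headroom, just over the largest actual per-(dtype, hidden_size) token totals observed during capture.)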
+ if not max_tokens_dict: + max_tokens_dict = self.max_tokens_across_vp_stages + + if cpu_factor > 0: + host_tokens_dict = self.max_avg_tokens_across_vp_stages or self.max_tokens_across_vp_stages + cpu_scale = cpu_factor + elif cpu_factor < 0: + host_tokens_dict = self.max_tokens_across_vp_stages + cpu_scale = -cpu_factor + else: + host_tokens_dict = None + cpu_scale = 0.0 + + if max_tokens_dict is None: + log_single_rank( + logger, + logging.INFO, + "Paged stash: max_tokens_dict is None, skipping stash buffer allocation", + ) + return + for dtype, hidden_size in max_tokens_dict: + if dtype not in self.stash_buffers: + self.stash_buffers[dtype] = {} + assert hidden_size not in self.stash_buffers[dtype] + num_tokens = int(max_tokens_dict[dtype, hidden_size] * cuda_scale) + num_tokens_host = ( + int(host_tokens_dict[dtype, hidden_size] * cpu_scale) + if host_tokens_dict is not None and (dtype, hidden_size) in host_tokens_dict + else 0 + ) + buf_dtype = torch.uint8 if dtype in [torch.float8_e4m3fn, torch.float8_e8m0fnu] else dtype + self.stash_buffers[dtype][hidden_size] = PagedStashBuffer( + num_tokens, + hidden_size, + self.page_size, + self.device, + self.overflow, + self.host_spill, + buf_dtype, + num_tokens_host=num_tokens_host, + ) + sb = self.stash_buffers[dtype][hidden_size] + msg = f'allocate_stash_buffers cuda: {sb.cuda_buffer.shape}' + if sb.host_buffer is not None: + msg += f' host: {sb.host_buffer.shape}' + msg += f' dtype={sb.dtype} ({dtype})' + log_single_rank(logger, logging.INFO, msg) + + def release_stash_buffers(self): + """Drop large stash CUDA/host page buffers after full-iteration CUDA graph teardown (fallback). + + Shared ``overflow`` / ``host_spill`` scalars are retained (small). Reallocation of page + buffers happens on the next ``paged_stash_reset`` while status remains ``captured``. + """ + if torch.cuda.is_available(): + torch.cuda.synchronize() + self.stash_buffers = None + log_single_rank( + logger, + logging.INFO, + "Paged stash: released stash page buffers after fallback (reallocated on next stash reset).", + ) + + def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): + """Update the pp schedule.""" + if self._pp_schedule is None: + self._pp_schedule = [] + + assert self.vp_size is not None + if layer_no is None: + # forward pass + vp_stage_index = vp_stage - 1 + layer_no = self.current_layer[vp_stage_index] + self.current_layer[vp_stage_index] += 1 + microbatch_no = self.current_microbatch[vp_stage_index] + + if self.status == 'capture': + self._pp_schedule.append(self.get_schedule_layer(vp_stage, layer_no, microbatch_no)) + + expected = self.get_schedule_layer(vp_stage, layer_no, microbatch_no) + actual = self._pp_schedule[self.current_schedule_index] + assert actual == expected, f"schedule {actual} != {expected}" + + return layer_no, microbatch_no + + + def update_model_chunk(self, vp_stage_index): + """Update layer=1, increment microbatch of new vp vp_stage.""" + if self.current_layer is None: + # current layer and microbatch for each vp stage for forward pass + self.current_layer = [1 for _ in range(self.vp_size)] + self.current_microbatch = [0 for _ in range(self.vp_size)] + self.current_layer[vp_stage_index] = 1 + self.current_microbatch[vp_stage_index] += 1 + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + """ + Hook called when autograd saves a tensor for backward pass. + Returns a tag to identify the tensor later. 
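+        Tensors that are 0-dim, lack a ``grouped_tensor_scale_inv`` attribute, or arrive before a token budget is set bypass paged stashing and are returned detached; qualifying tensors are wrapped in a PagedTensor.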
+ """ + # Handle 0-dim tensors (torch.Size([])) - they have no size(0) + if ( + self.max_num_tokens is None + or tensor.dim() == 0 + or not hasattr(tensor, 'grouped_tensor_scale_inv') + ): + return tensor.detach() + + assert isinstance(tensor, torch.Tensor), f"tensor is not a torch.Tensor {type(tensor)}" + + original_shape = tensor.shape + columnwise_scale_inv = tensor.grouped_tensor_scale_inv + tensor = tensor.flatten() + dtype = tensor.dtype + hidden_size = tensor.numel() // (self.max_num_tokens if not columnwise_scale_inv else self.max_num_tokens // SCALE_INV_BLOCK_SIZE) + + if self.max_tokens_across_vp_stages is None: + self.max_tokens_across_vp_stages = {} + self.temp_tokens_across_vp_stages = {} + self.max_avg_tokens_across_vp_stages = {} + self.temp_avg_tokens_across_vp_stages = {} + + avg_num_tokens = None + if self.status == 'capture': + + self.num_tokens = self.num_tokens_tensor.item() + actual_num_tokens = self.num_tokens // SCALE_INV_BLOCK_SIZE if columnwise_scale_inv else self.num_tokens + + avg_num_tokens = ( + int(self.avg_num_tokens) if self.avg_num_tokens is not None else None + ) + + if (dtype, hidden_size) not in self.temp_tokens_across_vp_stages: + self.temp_tokens_across_vp_stages[dtype, hidden_size] = 0 + self.max_tokens_across_vp_stages[dtype, hidden_size] = 0 + self.temp_avg_tokens_across_vp_stages[dtype, hidden_size] = 0 + self.max_avg_tokens_across_vp_stages[dtype, hidden_size] = 0 + + self.temp_tokens_across_vp_stages[dtype, hidden_size] += actual_num_tokens + self.max_tokens_across_vp_stages[dtype, hidden_size] = max( + self.max_tokens_across_vp_stages[dtype, hidden_size], + self.temp_tokens_across_vp_stages[dtype, hidden_size], + ) + + # Track avg tokens across vp stages (if provided) using the same accumulation model. + if avg_num_tokens is not None: + self.temp_avg_tokens_across_vp_stages[dtype, hidden_size] += (avg_num_tokens if not columnwise_scale_inv else avg_num_tokens // SCALE_INV_BLOCK_SIZE) + self.max_avg_tokens_across_vp_stages[dtype, hidden_size] = max( + self.max_avg_tokens_across_vp_stages[dtype, hidden_size], + self.temp_avg_tokens_across_vp_stages[dtype, hidden_size], + ) + + # Since capture stage does not use CUDA graph, we can truncate + # the saved tensor to actual num_tokens + new_size = (actual_num_tokens * hidden_size,) + + tensor_truncated = torch.empty(new_size, dtype=dtype, device=tensor.device) + tensor_truncated.copy_(tensor[: actual_num_tokens * hidden_size]) + tensor = tensor_truncated + + tensor.grouped_tensor_scale_inv = columnwise_scale_inv + paged_tensor = PagedTensor( + tensor, + num_tokens_tensor=self.num_tokens_tensor, + avg_num_tokens=avg_num_tokens, + vp_stage=self.current_vp_stage, + original_shape=original_shape, + schedule_layer_no=( + self._pp_schedule[self.current_schedule_index] + if self._pp_schedule is not None + and self.current_schedule_index < len(self._pp_schedule) + else None + ), + is_columnwise_scale_inv=columnwise_scale_inv, + max_num_tokens=self.max_num_tokens, + hidden_size=hidden_size, + page_size=self.page_size, + ) + + if self.status == 'captured': + self.add_paged_tensor_to_stash(paged_tensor) + return paged_tensor + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + """ + Hook called when autograd retrieves a saved tensor during backward pass. + Returns the actual tensor (potentially reloading from CPU). 
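+        PagedTensor entries are returned at their original shape with padding restored to ``max_num_tokens``; any other saved object is returned unchanged.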
+ """ + if isinstance(saved_state, (PagedTensor)): + columnwise_scale_inv = saved_state.is_columnwise_scale_inv + if self.status == 'capture': + num_tokens = saved_state.num_tokens_tensor.item() + key = (saved_state.dtype, saved_state.hidden_size) + if key in self.temp_tokens_across_vp_stages: + self.temp_tokens_across_vp_stages[key] -= (num_tokens if not columnwise_scale_inv else num_tokens // SCALE_INV_BLOCK_SIZE) + if ( + saved_state.avg_num_tokens is not None + and key in self.temp_avg_tokens_across_vp_stages + ): + self.temp_avg_tokens_across_vp_stages[key] -= (int(saved_state.avg_num_tokens) if not columnwise_scale_inv else int(saved_state.avg_num_tokens) // SCALE_INV_BLOCK_SIZE) + + # Handle 1-byte tensors (torch.uint8) + dtype = saved_state._tensor.dtype + if saved_state._tensor.element_size() == 1: + saved_state._tensor = saved_state._tensor.view(torch.uint8) + + # Pad the tensor to the max number of tokens + # check if the tensor is 1D + assert saved_state._tensor.ndim == 1, f"saved_state._tensor.ndim is not 1 {saved_state._tensor.ndim}" + npad = (self.max_num_tokens - num_tokens) * saved_state.hidden_size + if columnwise_scale_inv: + npad = npad // SCALE_INV_BLOCK_SIZE + pad = (0, npad) + saved_state._tensor = torch.nn.functional.pad(saved_state._tensor, pad).view(dtype) + + assert ( + saved_state._tensor is not None + ), f"saved_state._tensor is None {saved_state._tensor}" + + # Record cross-stream usage (important when tensor was produced on another stream). + if isinstance(saved_state._tensor, torch.Tensor) and saved_state._tensor.is_cuda: + saved_state._tensor.record_stream(torch.cuda.current_stream()) + + return saved_state._tensor.view(saved_state.original_shape) + + return saved_state + + +class PagedStashContext: + """Wrapper context manager that adds custom enter/exit behavior around saved_tensors_hooks.""" + + def __init__(self, stash_manager): + self.stash_manager = stash_manager + self.saved_tensors_context = torch.autograd.graph.saved_tensors_hooks( + stash_manager.on_save_for_backward, stash_manager.on_get_saved_tensor + ) + + def __enter__(self): + result = self.saved_tensors_context.__enter__() + + # Add more custom logic after entering if needed + return result + + def __exit__(self, *args: Any): + # Call the underlying context manager's __exit__ + result = self.saved_tensors_context.__exit__(*args) + return result + + +def paged_stash_group_start(tensor): + """Mark the start of a layer group and prepare for stash/reload.""" + stash_manager = PagedStashManager.get_instance() + if not stash_manager.enabled: + return tensor + return PipelinePreScheduleFunction.apply(tensor, stash_manager) + + +def get_paged_stash_context( + name=None, + max_num_tokens=None, + num_tokens_tensor=None, + avg_num_tokens=None, +): + """Get the paged stash context""" + stash_manager = PagedStashManager.get_instance() + if not stash_manager.enabled: + return nullcontext() + stash_manager.max_num_tokens = max_num_tokens + stash_manager.avg_num_tokens = avg_num_tokens + assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) + stash_manager.num_tokens_tensor = num_tokens_tensor + stash_manager.set_current_layer_name(name) if name is not None else None + pack_unpack_context = PagedStashContext(stash_manager) + return pack_unpack_context + + +def paged_stash_group_commit(tensor, name=None): + """Mark the end of a layer group and prepare for stash/reload.""" + stash_manager = PagedStashManager.get_instance() + stash_manager.device = tensor.device + if not 
stash_manager.enabled: + return tensor + return PipelinePostScheduleFunction.apply(tensor, stash_manager) + + +def paged_stash_init_chunk_handler(vp_size, vp_stage): + """Initialize the chunk handler, called at the start of a microbatch forward pass.""" + stash_manager = PagedStashManager.get_instance() + stash_manager.vp_size = vp_size if vp_size is not None else 1 + stash_manager.current_vp_stage = vp_stage if vp_stage is not None else 0 + stash_manager.update_model_chunk(stash_manager.current_vp_stage) + +def paged_stash_reset(enabled=True, config=None): + """Reset the chunk handler, called at the start of a training iteration. + + config: optional TransformerConfig; if provided, moe_paged_stash_buffer_size_factor_cuda/cpu and + moe_paged_stash_page_size are read from it. Otherwise defaults to 1.10 (CUDA), 0.0 (CPU). + """ + stash_manager = PagedStashManager.get_instance() + stash_manager.enabled = enabled + stash_manager.iteration += 1 + if config is not None: + stash_manager.page_size = config.moe_paged_stash_page_size + # current layer and microbatch for each vp stage for forward pass + stash_manager.current_schedule_index = 0 + + if not enabled: + return + + if stash_manager.status == 'begin': + stash_manager.status = 'capture' + elif stash_manager.status == 'capture': + stash_manager.status = 'captured' + cuda_factor = config.moe_paged_stash_buffer_size_factor_cuda if config is not None else 1.10 + cpu_factor = config.moe_paged_stash_buffer_size_factor_cpu if config is not None else 0.0 + stash_manager.allocate_stash_buffers( + moe_paged_stash_buffer_size_factor_cuda=cuda_factor, + moe_paged_stash_buffer_size_factor_cpu=cpu_factor, + ) + elif stash_manager.status == 'captured': + # Buffers may have been released after a PagedStashRunner fallback; reallocate using + # the same capture-derived maxima and current config factors. + if stash_manager.stash_buffers is None: + cuda_factor = config.moe_paged_stash_buffer_size_factor_cuda if config is not None else 1.10 + cpu_factor = config.moe_paged_stash_buffer_size_factor_cpu if config is not None else 0.0 + stash_manager.allocate_stash_buffers( + moe_paged_stash_buffer_size_factor_cuda=cuda_factor, + moe_paged_stash_buffer_size_factor_cpu=cpu_factor, + ) + + if stash_manager.status == 'captured': + assert stash_manager.stash_buffers is not None, ( + "Paged stash: captured state but stash_buffers is None after reset/allocation." 
+ ) + for dtype in stash_manager.stash_buffers.keys(): + for hidden_size in stash_manager.stash_buffers[dtype].keys(): + stash_manager.stash_buffers[dtype][hidden_size].reset() + stash_manager.overflow.zero_() + stash_manager.host_spill.zero_() + stash_manager.current_layer = [1 for _ in range(stash_manager.vp_size)] + stash_manager.current_microbatch = [0 for _ in range(stash_manager.vp_size)] + assert ( + len(stash_manager.paged_tensors_to_stash) == 0 + ), f"paged_tensors_to_stash is not empty {stash_manager.paged_tensors_to_stash}" + assert len(stash_manager.paged_tensors_stash_in_progress) == 0, ( + f"paged_tensors_stash_in_progress is not empty " + f"{stash_manager.paged_tensors_stash_in_progress}" + ) + +def check_paged_stash_overflow(): + """Check if paged stash overflow""" + stash_manager = PagedStashManager.get_instance() + if not stash_manager.enabled or stash_manager.overflow is None: + return torch.zeros(1, dtype=torch.bool, device='cuda') + overflow = stash_manager.overflow.ne(0) + return overflow + + +def check_paged_stash_host_spill(): + """True if any activation was stashed to pinned host (successful spill, not overflow path).""" + stash_manager = PagedStashManager.get_instance() + if not stash_manager.enabled or stash_manager.host_spill is None: + return torch.zeros(1, dtype=torch.bool, device='cuda') + return stash_manager.host_spill.ne(0) + + +class PagedStashRunner: + """Runner for paged stash""" + + def __init__(self, config, copy_main_params, model, optimizer, forward_backward_func): + self.stash_manager = PagedStashManager.get_instance() + self.config = config + self.copy_main_params = copy_main_params + self.model = model + self.optimizer = optimizer + self.forward_backward_func = forward_backward_func + self.moe_layers = [] + for model_chunk in self.model: + model_with_decoder = get_attr_wrapped_model( + model_chunk, "decoder", allow_none=False, return_model_obj=True + ) + for layer in model_with_decoder.decoder.layers: + mlp = layer.mlp + if hasattr(mlp, 'token_dispatcher') and hasattr( + mlp.token_dispatcher, 'check_over_budget' + ): + self.moe_layers.append(mlp) + if model_with_decoder.mtp_process: + for layer in model_with_decoder.mtp.layers: + mlp = layer.mtp_model_layer.mlp + if hasattr(mlp, 'token_dispatcher') and hasattr( + mlp.token_dispatcher, 'check_over_budget' + ): + self.moe_layers.append(mlp) + + def data_read(self, data_iterator, model, training, num_microbatches): + """Read all microbatch inputs from Dataloader and copy to static buffers.""" + data_iterator_saved = [] + if not isinstance(model, list) or len(model) == 1: + assert not isinstance(data_iterator, list) or len(data_iterator) == 1 + iterator0 = data_iterator if not isinstance(data_iterator, list) else data_iterator[0] + data_list = [] + if iterator0 is not None: + for b in range(num_microbatches): + data_list.append(next(iterator0)) + data_iterator_saved.append(iter(data_list)) + data_list = [iter(data_list)] + else: + data_iterator_saved.append(None) + data_list.append(None) + else: + assert isinstance(data_iterator, list) and len(data_iterator) == len(model) + data_list = [] + for i in range(len(model)): + if data_iterator[i] is not None: + data_list_i = [] + for b in range(num_microbatches): + data_list_i.append(next(data_iterator[i])) + data_iterator_saved.append(iter(data_list_i)) + data_list.append(iter(data_list_i)) + else: + data_iterator_saved.append(None) + data_list.append(None) + return data_iterator_saved, data_list + + def check_moe_overflow(self): + 
"""(stash_overflow_rank_sum, overbudget_rank_sum, host_spill_rank_sum); one all_reduce.""" + stash_overflow = check_paged_stash_overflow().view(-1)[0] + host_spill = check_paged_stash_host_spill().view(-1)[0] + overbudget = torch.zeros(1, dtype=torch.bool, device=stash_overflow.device).view(-1)[0] + for mlp in self.moe_layers: + ob = mlp.token_dispatcher.check_over_budget() + if ob is not None: + overbudget |= ob.view(-1)[0] + + flags = torch.stack( + [ + stash_overflow.to(torch.int32), + overbudget.to(torch.int32), + host_spill.to(torch.int32), + ], + dim=0, + ) + torch.distributed.all_reduce(flags, op=torch.distributed.ReduceOp.SUM) + return flags[0].item(), flags[1].item(), flags[2].item() + + def prepare_for_rerun(self, is_training=True): + """Prepare for rerun""" + log_single_rank( + logger, + logging.INFO, + "Paged stash: rerunning forward-backward without moe_expert_rank_capacity_factor padding " + "and with moe_paged_stash disabled.", + ) + # check for token dispatcher overflow + for mlp in self.moe_layers: + if hasattr(mlp, 'token_dispatcher') and hasattr( + mlp.token_dispatcher._comm_manager, 'moe_expert_rank_capacity_factor' + ): + mlp.token_dispatcher._comm_manager.moe_expert_rank_capacity_factor = None + mlp.token_dispatcher.reset_over_budget() + if self.stash_manager.overflow is not None: + self.stash_manager.overflow.zero_() + if self.stash_manager.host_spill is not None: + self.stash_manager.host_spill.zero_() + self.config.moe_paged_stash = False + + # Set grad to zero. + for model_chunk in self.model: + model_chunk.zero_grad_buffer() + if self.optimizer is not None: + self.optimizer.zero_grad() + + #_handle_mxfp8_param_buffer_copy + if self.copy_main_params: + def _try_copy_main_params(opt): + if isinstance(opt, DistributedOptimizer) and hasattr(opt, 'shard_fp32_from_float16_groups'): + opt._copy_main_params_to_param_buffer() + # Handle both ChainedOptimizer and direct DistributedOptimizer cases + # Note: FSDP's DistributedOptimizer doesn't have shard_fp32_from_float16_groups, + # so we check for this attribute before calling _copy_main_params_to_param_buffer + if self.optimizer is not None: + if hasattr(self.optimizer, 'chained_optimizers'): + for optim_instance in self.optimizer.chained_optimizers: + _try_copy_main_params(optim_instance) + else: + _try_copy_main_params(self.optimizer) + + # Delete the CUDA graph before releasing stash tensors the captured graph may reference. + if isinstance(self.forward_backward_func, FullCudaGraphWrapper): + self.forward_backward_func.reset_cuda_graph(stage='training' if is_training else 'validation') + + # Only drop page buffers on training fallback. Validation uses forward_only=True, so + # paged_stash_reset disables the stash manager and eval forward never reads/writes the + # large page buffers—freeing them here saves almost nothing. If we released on eval, + # the next training step would realloc new buffer addresses while the training + # FullCudaGraphWrapper could still replay a graph recorded against the old pointers. + # Training fallback resets the training graph before this path, so release + realloc + # remains consistent with capture. 
+ if is_training: + self.stash_manager.release_stash_buffers() + + def __call__(self, *args, **kwargs): + """Run the paged stash""" + assert len(args) == 0, 'forward_backward_func does not accept positional args' + assert all( + [ + kwarg in kwargs + for kwarg in [ + 'model', + 'data_iterator', + 'num_microbatches', + 'seq_length', + 'forward_only', + ] + ] + ) + model = kwargs['model'] + num_microbatches = kwargs['num_microbatches'] + + training = not kwargs['forward_only'] + data_iterator = kwargs['data_iterator'] + saved_moe_paged_stash = self.config.moe_paged_stash + num_tries = 0 + while True: + assert num_tries < 2, f"PagedStashRunner: num_tries {num_tries} exceeded max attempts!!!" + num_tries += 1 + data_iterator, data_list = self.data_read(data_iterator, model, training, num_microbatches) + + kwargs['data_iterator'] = data_list + result = self.forward_backward_func(*args, **kwargs) + + stash_overflow_ranks, overbudget_ranks, host_spill_ranks = self.check_moe_overflow() + # if no overflow, set the expert_rank_capacity_factor to the original value + if stash_overflow_ranks == 0 and overbudget_ranks == 0: + if host_spill_ranks > 0: + log_single_rank( + logger, + logging.INFO, + "Paged stash: spilled activations to pinned host " + f"on {host_spill_ranks} rank(s) (CUDA stash full). " + "Consider increasing moe_paged_stash_buffer_size_factor_cuda for potentially better performance.", + ) + for mlp in self.moe_layers: + if hasattr(mlp, 'token_dispatcher') and hasattr( + mlp.token_dispatcher._comm_manager, 'moe_expert_rank_capacity_factor' + ): + mlp.token_dispatcher._comm_manager.moe_expert_rank_capacity_factor = mlp.token_dispatcher.config.moe_expert_rank_capacity_factor + self.config.moe_paged_stash = saved_moe_paged_stash + break + + # if overflow or overbudget, set the expert_rank_capacity_factor to None + if overbudget_ranks > 0: + log_single_rank( + logger, + logging.INFO, + "Paged stash: token drop during MoE token dispatch (over budget) " + f"on {overbudget_ranks} rank(s). " + "Consider increasing moe_expert_rank_capacity_factor.", + ) + if stash_overflow_ranks > 0: + log_single_rank( + logger, + logging.INFO, + "Paged stash: stashing buffer overflow " + f"on {stash_overflow_ranks} rank(s). " + "Consider increasing moe_paged_stash_buffer_size_factor_cuda or moe_paged_stash_buffer_size_factor_cpu.", + ) + self.prepare_for_rerun(is_training=training) + return result diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 3d353666b70..2db080dfb9e 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -76,6 +76,7 @@ def __init__( self.tp_size = utils.get_pg_size(self.tp_group) self.tp_rank = utils.get_pg_rank(self.tp_group) self.ep_size = utils.get_pg_size(self.ep_group) + self.ep_rank = utils.get_pg_rank(self.ep_group) # Attributes that need to be captured in cudagraph. These attributes are returned # as cudagraph outputs when the cuda_graph_scope contains moe_preprocess. @@ -1021,10 +1022,23 @@ def __init__( "https://github.com/deepseek-ai/DeepEP/tree/hybrid-ep." 
) + self.moe_expert_rank_capacity_factor = self.config.moe_expert_rank_capacity_factor + self.over_budget = torch.zeros(1, dtype=torch.bool, device='cuda') + def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor): num_tokens = routing_map.shape[0] self.routing_map = routing_map.reshape(num_tokens, self.num_experts) self.token_probs = probs.reshape(num_tokens, self.num_experts) + + if self.moe_expert_rank_capacity_factor is not None: + pad_multiple = get_align_size_for_quantization(self.config) + budget = int( + routing_map.shape[0] + * self.config.moe_router_topk + * self.moe_expert_rank_capacity_factor + ) + budget += -budget % pad_multiple + self.num_permuted_tokens = budget # Compute the capacity for each expert at the drop_and_pad mode if self.drop_and_pad: num_out_tokens = num_tokens * self.config.moe_router_topk @@ -1069,12 +1083,16 @@ def dispatch( pad_multiple=self.pad_multiple, ) ) + if self.moe_expert_rank_capacity_factor is not None: + over_budget = self.handle[8] != 0 # this is overflow_flag + self.over_budget |= over_budget - if not self.drop_and_pad: - self.tokens_per_expert = tokens_per_expert + if self.num_permuted_tokens is None: + self.tokens_per_expert = tokens_per_expert.to(torch.int64) # self.num_permuted_tokens is necessary to allocate the output tensor for permute self.num_permuted_tokens = self.tokens_per_expert.sum() - + if self.moe_expert_rank_capacity_factor is not None: + self.tokens_per_expert = tokens_per_expert.to(torch.int64) return dispatched_hidden def combine( @@ -1429,6 +1447,7 @@ def _initialize_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor) - .expand(-1, -1, self.tp_size, -1) .reshape(num_local_tokens, world_size, self.num_local_experts) ).contiguous() + return routing_map, probs @jit_fuser @@ -1549,3 +1568,15 @@ def combine_postprocess(self, hidden_states: torch.Tensor): The final MoE layer output reshaped to its original dimensions. """ return hidden_states.view(self.hidden_shape) + + def check_over_budget(self): + """Check if the dispatcher has exceeded its budget.""" + if hasattr(self._comm_manager, 'over_budget'): + return self._comm_manager.over_budget + else: + return None + + def reset_over_budget(self): + """Reset the accumulated over-budget flag on the communication manager.""" + if hasattr(self._comm_manager, 'over_budget'): + self._comm_manager.over_budget.fill_(0) diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index d7f70aa7e07..ba5018a945e 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -690,9 +690,10 @@ def process_mtp_loss( mtp_loss = compute_language_model_loss(mtp_labels, mtp_logits) mtp_loss = loss_mask * mtp_loss if is_training: + # Safe divide without sync: mask numerator when num_tokens==0, divide by clamp(min=1) mtp_loss_for_log = ( - torch.sum(mtp_loss) / num_tokens if num_tokens > 0 else mtp_loss.new_tensor(0.0) - ) + torch.sum(mtp_loss) * (num_tokens > 0).to(mtp_loss.dtype) + ) / num_tokens.clamp(min=1) MTPLossLoggingHelper.save_loss_to_tracker( mtp_loss_for_log, mtp_layer_number, diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 3dcc96825fa..88da4bec19b 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -825,10 +825,14 @@ class TransformerConfig(ModelParallelConfig): block interleaved format. 
Instead of interpreting the input tensor as a concatenation of gates and linear units, it will be interpreted as alternating blocks of gates and linear units. - This data format is experimental and primarily intended to enable advanced fused kernels.""" + moe_expert_rank_capacity_factor: Optional[float] = None + """moe_expert_rank_capacity_factor (float): The capacity factor for each expert rank. Tokens + exceeding this budget will be dropped. None means no token will be dropped. + The default is None.""" + ################## # Context Parallel ################## @@ -1089,6 +1093,20 @@ class TransformerConfig(ModelParallelConfig): activation_offload_fraction: float = 1.0 """The fraction of the activation to be offloaded, which should be in range [0, 1].""" + + moe_paged_stash: bool = False + """If True, enable paged stash for all routed-expert activations needed for backward""" + + moe_paged_stash_page_size: int = 64 + """Number of tokens per page for paged stash memory management.""" + + moe_paged_stash_buffer_size_factor_cuda: float = 1.10 + """Scale factor for paged stash CUDA buffer allocation. Sign selects sizing: positive = avg-based, + negative = actual-max. Magnitude is headroom (e.g. 1.10 = 10%).""" + + moe_paged_stash_buffer_size_factor_cpu: float = 0.0 + """Scale factor for paged stash host buffer. 0 disables host buffer. Same sign convention as + moe_paged_stash_buffer_size_factor_cuda: positive = avg-based, negative = actual-max; scale = abs(factor).""" def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. @@ -1383,6 +1401,18 @@ def __post_init__(self): "moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity" ) + if self.moe_expert_rank_capacity_factor is not None: + if not self.use_transformer_engine_op_fuser: + raise ValueError( + "moe_expert_rank_capacity_factor requires use_transformer_engine_op_fuser to " + "be enabled." + ) + if self.moe_flex_dispatcher_backend != "hybridep": + raise ValueError( + "moe_expert_rank_capacity_factor requires moe_flex_dispatcher_backend to be " + "'hybridep'." + ) + if self.cpu_offloading and ( self.cpu_offloading_num_layers < 0 or self.cpu_offloading_num_layers >= self.num_layers ): @@ -1624,6 +1654,20 @@ def __post_init__(self): assert ( self.delta_offload_bytes_across_pp_ranks >= 0 ), "delta_offload_bytes_across_pp_ranks must be non-negative." + if self.moe_paged_stash: + assert not self.cpu_offloading, ( + "moe_paged_stash cannot be enabled with cpu_offloading." + ) + assert self.moe_expert_rank_capacity_factor is not None, ( + "moe_paged_stash requires moe_expert_rank_capacity_factor to be set; " + "there is no need to use paged stashing without it." + ) + moe_offload_conflict = {"expert_fc1", "moe_act"} & set(self.offload_modules) + assert not moe_offload_conflict, ( + "When moe_paged_stash is enabled, offload_modules must not include " + f"expert_fc1 or moe_act (paged stash covers those activations). 
" + f"Remove: {moe_offload_conflict}" + ) if ( self.num_layers_in_first_pipeline_stage is not None @@ -2307,14 +2351,15 @@ def __post_init__(self): ) if self.cuda_graph_impl != "none": - assert ( - self.cuda_graph_impl == "transformer_engine" - and CudaGraphScope.moe not in self.cuda_graph_scope - and CudaGraphScope.mlp not in self.cuda_graph_scope - ), ( - 'CUDA graph scope on moe and mlp is not ' - 'supported with overlap_moe_expert_parallel_comm' - ) + if self.cuda_graph_impl == "transformer_engine": + assert ( + self.cuda_graph_impl == "transformer_engine" + and CudaGraphScope.moe not in self.cuda_graph_scope + and CudaGraphScope.mlp not in self.cuda_graph_scope + ), ( + 'CUDA graph scope on moe and mlp is not ' + 'supported with overlap_moe_expert_parallel_comm' + ) # Check delay_wgrad_compute compatibility if self.delay_wgrad_compute: diff --git a/megatron/training/training.py b/megatron/training/training.py index bbb029e79da..b01320b2179 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -159,6 +159,7 @@ def set_startup_timestamps(program_start=None, main_entry=None): from megatron.core.transformer.cuda_graphs import TECudaGraphHelper from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.module import Float16Module +from megatron.core.transformer.moe.paged_stash import PagedStashRunner from megatron.core.distributed import DistributedDataParallelConfig, TorchFullyShardedDataParallelConfig from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel as megatron_FSDP @@ -1381,7 +1382,7 @@ def build_model(): pre_process=pre_process, post_process=post_process, vp_stage=i, - config=config, + config=config if i==0 else get_model_config(model[0]), pg_collection=pg_collection, ) this_model.model_type = model_type @@ -2863,7 +2864,20 @@ def finalize_model_grads_with_state_reload(*fmg_args, **fmg_kwargs): # Wrap forward_backward_func for Full iteration CUDA graph forward_backward_func = get_forward_backward_func() if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope: - forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps) + forward_backward_func = FullCudaGraphWrapper( + forward_backward_func, + cuda_graph_warmup_steps=args.cuda_graph_warmup_steps, + ) + # Wrap forward_backward_func for overflow handling with moe_expert_rank_capacity_factor + if args.moe_expert_rank_capacity_factor is not None: + copy_main_params = args.reuse_grad_buf_for_mxfp8_param_ag and args.overlap_param_gather + forward_backward_func = PagedStashRunner( + config, + copy_main_params, + model, + optimizer, + forward_backward_func, + ) def get_e2e_base_metrics(): """Get base metrics values for one-logger to calculate E2E tracking metrics.""" @@ -3366,7 +3380,20 @@ def evaluate( eval_num_microbatches = eval_batch_size // (args.micro_batch_size * args.data_parallel_size) forward_backward_func = get_forward_backward_func() if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope: - forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps) + forward_backward_func = FullCudaGraphWrapper( + forward_backward_func, + cuda_graph_warmup_steps=args.cuda_graph_warmup_steps, + ) + # Wrap forward_backward_func for overflow handling with moe_expert_rank_capacity_factor + 
if args.moe_expert_rank_capacity_factor is not None: + copy_main_params = args.reuse_grad_buf_for_mxfp8_param_ag and args.overlap_param_gather + forward_backward_func = PagedStashRunner( + config, + copy_main_params, + model, + None, + forward_backward_func, + ) if has_nvidia_modelopt: # [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors diff --git a/megatron/training/utils.py b/megatron/training/utils.py index ba470f165ec..8acfe1a169f 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -568,7 +568,7 @@ def _broadcast(item): def _broadcast_cu_seqlens(cu_seqlens): dev = torch.cuda.current_device() n = 0 if cu_seqlens is None else int(cu_seqlens.numel()) - n_tensor = torch.tensor(n, dtype=torch.int64, device=dev) + n_tensor = torch.empty(1, dtype=torch.int64, device=dev).fill_(n) _broadcast(n_tensor) if n == 0: diff --git a/tests/unit_tests/transformer/moe/test_paged_stashing.py b/tests/unit_tests/transformer/moe/test_paged_stashing.py new file mode 100644 index 00000000000..a424b140f6f --- /dev/null +++ b/tests/unit_tests/transformer/moe/test_paged_stashing.py @@ -0,0 +1,396 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import pytest +import torch +import torch.nn.functional as F + +from megatron.core import config +from megatron.core.fp8_utils import get_fp8_context +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.transformer.moe.moe_layer import MoELayer +from megatron.core.transformer.moe.moe_utils import get_align_size_for_quantization +from megatron.core.transformer.moe.experts import TEGroupedMLP +from megatron.core.transformer.moe.paged_stash import ( + check_paged_stash_overflow, + paged_stash_init_chunk_handler, + paged_stash_reset, +) +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.training.initialize import _set_random_seed +from tests.unit_tests.test_utilities import Utils + + +def _global_tokens_per_expert_from_local_routing_map(routing_map: torch.Tensor) -> torch.Tensor: + """Per-expert token counts from a local routing map, summed across the default process group. + + ``routing_map`` is shaped [num_local_token_rows, num_experts] (as in + ``_HybridEPManager``). Tests here assume world size equals expert-parallel size (all GPUs + are EP ranks); ``all_reduce`` on the world group aggregates disjoint local maps. 
+ """ + counts = routing_map.sum(dim=0).to(torch.int64) + if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: + torch.distributed.all_reduce(counts, op=torch.distributed.ReduceOp.SUM) + return counts + + +def _tokens_per_expert_from_routing_map(routing_map: torch.Tensor, layer: MoELayer) -> torch.Tensor: + """Per-local-expert assignment counts from the routing map (columns for this EP rank).""" + counts = _global_tokens_per_expert_from_local_routing_map(routing_map) + idx = torch.as_tensor(layer.local_expert_indices, device=counts.device, dtype=torch.long) + return counts[idx].to(torch.int64).clone() + + +def _pad_token_counts_to_align_size( + tokens_per_expert: torch.Tensor, pad_multiple: int +) -> torch.Tensor: + """Round each count up to a multiple of ``pad_multiple`` (``n + (-n % m)`` like budget).""" + t = tokens_per_expert.to(torch.int64) + return t + (-t % pad_multiple) + + +class MoEModelTestContainer: + def __init__( + self, + tp_size, + ep_size, + pp_size, + cp_size=1, + moe_tp_size=None, + data_parallel_random_init=False, + num_moe_experts=8, + num_layers=1, + moe_router_topk=2, + moe_router_load_balancing_type="aux_loss", + moe_token_dispatcher_type="alltoall", + moe_expert_capacity_factor=None, + moe_pad_expert_input_to_capacity=False, + moe_aux_loss_coeff=0.1, + test_dtype=torch.float32, + **kwargs, + ): + self.num_local_experts = num_moe_experts // ep_size + self.num_layers = num_layers + self.test_dtype = test_dtype + if moe_tp_size is None: + moe_tp_size = tp_size + Utils.initialize_model_parallel( + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + expert_model_parallel_size=ep_size, + context_parallel_size=cp_size, + expert_tensor_parallel_size=moe_tp_size, + ) + _set_random_seed(seed_=123, data_parallel_random_init=data_parallel_random_init) + self.config = TransformerConfig( + tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + pipeline_model_parallel_size=pp_size, + context_parallel_size=cp_size, + expert_tensor_parallel_size=moe_tp_size, + fp8='e4m3', + fp8_recipe='mxfp8', + fp8_wgrad=True, + fp8_amax_compute_algo='most_recent', + fp8_amax_history_len=1, + fp8_interval=1, + fp8_margin=0, + moe_router_topk=moe_router_topk, + num_moe_experts=num_moe_experts, + moe_router_load_balancing_type=moe_router_load_balancing_type, + moe_token_dispatcher_type=moe_token_dispatcher_type, + moe_expert_capacity_factor=moe_expert_capacity_factor, + moe_pad_expert_input_to_capacity=moe_pad_expert_input_to_capacity, + moe_aux_loss_coeff=moe_aux_loss_coeff, + num_layers=num_layers, + moe_router_dtype="fp32", + hidden_size=kwargs.get("hidden_size", 16), + num_attention_heads=kwargs.get("num_attention_heads", 8), + use_cpu_initialization=kwargs.get("use_cpu_initialization", True), + sequence_parallel=tp_size > 1, + add_bias_linear=kwargs.get("add_bias_linear", False), + moe_permute_fusion=kwargs.get("moe_permute_fusion", False), + moe_flex_dispatcher_backend=kwargs.get("moe_flex_dispatcher_backend", None), + moe_grouped_gemm=kwargs.get("moe_grouped_gemm", False), + moe_paged_stash=kwargs.get("moe_paged_stash", False), + moe_expert_rank_capacity_factor=kwargs.get("moe_expert_rank_capacity_factor", None), + moe_router_padding_for_fp8=kwargs.get("moe_router_padding_for_fp8", True), + use_transformer_engine_op_fuser=kwargs.get("use_transformer_engine_op_fuser", False), + moe_mlp_glu_interleave_size=kwargs.get("moe_mlp_glu_interleave_size", None), + moe_router_padding_for_quantization=kwargs.get( + 
"moe_router_padding_for_quantization", False + ), + gated_linear_unit=kwargs.get("gated_linear_unit", False), + activation_func=kwargs.get("activation_func", F.gelu), + moe_router_force_biased=kwargs.get("moe_router_force_biased", None), + moe_paged_stash_buffer_size_factor_cuda=0.5, + moe_paged_stash_buffer_size_factor_cpu=1.5, + ) + self.moe_layers = [ + self._create_moe_layer(layer_number=i) for i in range(num_layers) + ] + self.moe_layer = self.moe_layers[0] + + def _create_moe_layer(self, layer_number=0): + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=self.config.num_moe_experts, moe_grouped_gemm=True + ) + quantization_context = get_fp8_context(self.config, layer_number, is_init=True) + with quantization_context: + moe_layer = ( + MoELayer(self.config, transformer_layer_spec.submodules.mlp.submodules) + .cuda() + .to(dtype=self.test_dtype) + ) + moe_layer.set_layer_number(layer_number) + return moe_layer + + def zero_grad(self): + for layer in self.moe_layers: + layer.zero_grad() + + def __del__(self): + torch.distributed.barrier() + torch.cuda.synchronize() + Utils.destroy_model_parallel() + + def destroy(self): + Utils.destroy_model_parallel() + + +def _forward_backward_all_layers(container: MoEModelTestContainer, hidden_states: torch.Tensor): + """Forward/backward all MoE layers; returns output, input grad, last layer routing state.""" + initial_hidden_states = hidden_states.cuda().requires_grad_(True) + hidden_states = initial_hidden_states + quantization_context = get_fp8_context(container.config) + with quantization_context: + for layer in container.moe_layers: + hidden_states, _ = layer(hidden_states) + output = hidden_states + last_layer = container.moe_layers[-1] + comm = getattr(last_layer.token_dispatcher, "_comm_manager", None) + routing_map = getattr(comm, "routing_map", None) + tokens_per_expert = ( + comm.get_number_of_tokens_per_expert() + if comm is not None and hasattr(comm, "get_number_of_tokens_per_expert") + else None + ) + output.backward(torch.ones_like(output)) + return ( + output.detach(), + initial_hidden_states.grad, + routing_map, + tokens_per_expert, + ) + + +def is_hybrid_ep_available(): + from megatron.core.transformer.moe.fused_a2a import HAVE_HYBRIDEP + return HAVE_HYBRIDEP + + +@pytest.mark.skipif(not is_hybrid_ep_available(), reason="Hybrid EP are not available") +class TestPagedStashing: + def setup_method(self, method): + pass + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal + def test_forward_backward_4_layers(self): + """Test paged stashing with 4 MoE layers: ref run vs paged run match.""" + if not is_hybrid_ep_available(): + pytest.skip("Hybrid EP is not available") + + config.ENABLE_EXPERIMENTAL = True + + container = MoEModelTestContainer( + tp_size=1, + ep_size=4, + pp_size=1, + num_moe_experts=8, + num_layers=4, + moe_router_topk=2, + moe_router_load_balancing_type="aux_loss", + moe_token_dispatcher_type="flex", + moe_permute_fusion=True, + hidden_size=1024, + moe_flex_dispatcher_backend="hybridep", + test_dtype=torch.bfloat16, + moe_grouped_gemm=True, + moe_use_legacy_grouped_gemm=False, + moe_paged_stash=True, + moe_expert_rank_capacity_factor=1.5, + use_transformer_engine_op_fuser=True, + moe_mlp_glu_interleave_size=32, + moe_router_padding_for_quantization=True, + gated_linear_unit=True, + activation_func=F.silu, + ) + experts = container.moe_layer.experts + fused_ok = 
isinstance(experts, TEGroupedMLP) and experts._is_fused_impl_supported() + if not fused_ok: + container.destroy() + pytest.skip("TEGroupedMLP fused impl not supported") + + seq_length = 1024 + batch_size = 1 + hidden_size = container.config.hidden_size + hidden_states = torch.randn( + (seq_length, batch_size, hidden_size), dtype=torch.bfloat16 + ) + + # First iteration: capture schedule, capacity, etc. + paged_stash_reset(True, config=container.config) + paged_stash_init_chunk_handler(1, 0) + output_ref, hidden_states_grad_ref, routing_map_ref, tokens_per_expert_ref = ( + _forward_backward_all_layers(container, hidden_states) + ) + + container.zero_grad() + + # Second iteration: run with paged stash. + paged_stash_reset(True, config=container.config) + paged_stash_init_chunk_handler(1, 0) + output, hidden_states_grad, routing_map, tokens_per_expert = _forward_backward_all_layers( + container, hidden_states + ) + + overflow = check_paged_stash_overflow() + assert overflow.any().item() == 0 + + assert torch.allclose(output, output_ref, atol=1e-4, rtol=1e-4), ( + f"output != output_ref: max diff = {(output - output_ref).abs().max().item()}" + ) + assert torch.allclose(hidden_states_grad, hidden_states_grad_ref, atol=1e-4, rtol=1e-4), ( + f"hidden_states_grad != ref: max diff = " + f"{(hidden_states_grad - hidden_states_grad_ref).abs().max().item()}" + ) + if routing_map is not None and tokens_per_expert is not None: + num_tokens_per_ep_rank = tokens_per_expert.sum().item() + assert num_tokens_per_ep_rank > 0, ( + f"num_tokens_per_ep_rank={num_tokens_per_ep_rank} (expected > 0)" + ) + assert routing_map_ref is not None and tokens_per_expert_ref is not None + tpe_f = tokens_per_expert.float() + ref_f = tokens_per_expert_ref.float() + assert torch.allclose(tpe_f, ref_f, atol=1e-4, rtol=1e-4), ( + f"tokens_per_expert != ref: max diff = {(tpe_f - ref_f).abs().max().item()}" + ) + + +@pytest.mark.skipif(not is_hybrid_ep_available(), reason="Hybrid EP are not available") +class TestPagedStashingOverBudget: + def setup_method(self, method): + pass + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal + def test_overload_factor_and_over_budget(self): + """Budget matches HybridEP setup_metadata; over_budget matches map-derived load.""" + if not is_hybrid_ep_available(): + pytest.skip("Hybrid EP is not available") + + config.ENABLE_EXPERIMENTAL = True + + container = MoEModelTestContainer( + tp_size=1, + ep_size=4, + pp_size=1, + num_moe_experts=8, + num_layers=4, + moe_router_topk=2, + moe_router_load_balancing_type="aux_loss", + moe_token_dispatcher_type="flex", + moe_permute_fusion=True, + hidden_size=1024, + moe_flex_dispatcher_backend="hybridep", + test_dtype=torch.bfloat16, + moe_grouped_gemm=True, + moe_use_legacy_grouped_gemm=False, + moe_paged_stash=True, + moe_expert_rank_capacity_factor=1.5, + use_transformer_engine_op_fuser=True, + moe_mlp_glu_interleave_size=32, + moe_router_padding_for_quantization=True, + gated_linear_unit=True, + activation_func=F.silu, + moe_router_force_biased=1, + ) + experts = container.moe_layer.experts + fused_ok = isinstance(experts, TEGroupedMLP) and experts._is_fused_impl_supported() + if not fused_ok: + container.destroy() + pytest.skip("TEGroupedMLP fused impl not supported") + + seq_length = 1024 + batch_size = 1 + topk = container.config.moe_router_topk + capacity_factor = container.config.moe_expert_rank_capacity_factor + 
hidden_states = torch.randn( + (seq_length, batch_size, container.config.hidden_size), dtype=torch.bfloat16 + ) + + num_tokens = seq_length * batch_size * topk + pad_multiple = get_align_size_for_quantization(container.config) + budget = int(num_tokens * capacity_factor) + budget += -budget % pad_multiple + + paged_stash_reset(True, config=container.config) + paged_stash_init_chunk_handler(1, 0) + _forward_backward_all_layers(container, hidden_states) + + overflow = check_paged_stash_overflow() + num_layers = len(container.moe_layers) + stash_cuda = container.config.moe_paged_stash_buffer_size_factor_cuda + stash_cpu = container.config.moe_paged_stash_buffer_size_factor_cpu + stash_buffer_size = num_tokens * num_layers * (stash_cuda + stash_cpu) + + total_tokens = 0 + for layer_idx, layer in enumerate(container.moe_layers): + comm = getattr(layer.token_dispatcher, "_comm_manager", None) + routing_map = getattr(comm, "routing_map", None) if comm is not None else None + over_budget_tensor = ( + layer.token_dispatcher.check_over_budget() + if hasattr(layer.token_dispatcher, "check_over_budget") + else None + ) + over_budget = over_budget_tensor.item() if over_budget_tensor is not None else False + + assert routing_map is not None, f"layer {layer_idx}: routing_map is None" + assert routing_map.dim() == 2, f"layer {layer_idx}: expected 2D routing_map" + assert routing_map.shape[1] == container.config.num_moe_experts, ( + f"layer {layer_idx}: routing_map has {routing_map.shape[1]} experts, " + f"expected {container.config.num_moe_experts}" + ) + tokens_per_expert_from_map = _tokens_per_expert_from_routing_map(routing_map, layer) + tokens_per_expert_from_map_padded = _pad_token_counts_to_align_size( + tokens_per_expert_from_map, pad_multiple + ) + tokens_per_ep_rank_from_map = tokens_per_expert_from_map_padded.sum().item() + total_tokens += tokens_per_ep_rank_from_map + + # Padded map-derived tokens strictly over budget iff dispatcher reports over_budget + if tokens_per_ep_rank_from_map > budget: + assert over_budget, ( + f"layer {layer_idx}: tokens_per_ep_rank_from_map " + f"({tokens_per_ep_rank_from_map}) > budget ({budget}), " + f"but over_budget flag was not set" + ) + else: + assert not over_budget, ( + f"layer {layer_idx}: tokens_per_ep_rank_from_map " + f"({tokens_per_ep_rank_from_map}) <= budget ({budget}), " + f"but over_budget flag was set" + ) + + overflow_set = overflow.any().item() + stash_exceeded = total_tokens > stash_buffer_size + assert overflow_set == stash_exceeded, ( + f"overflow {overflow_set} should match total_tokens > stash_buffer_size " + f"({total_tokens} > {stash_buffer_size})" + )
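[Editor's note, not part of the patch.] For readers tracing the numbers in these tests, here is a minimal, self-contained sketch of how the per-EP-rank token budget is derived from moe_expert_rank_capacity_factor, mirroring the budget computation added to setup_metadata and checked in test_overload_factor_and_over_budget. The helper name and the example alignment value are illustrative; in the real code the alignment comes from get_align_size_for_quantization(config).

def expert_rank_token_budget(num_local_tokens: int, topk: int,
                             capacity_factor: float, pad_multiple: int) -> int:
    # Raw budget: routed token slots allowed on this expert-parallel rank.
    budget = int(num_local_tokens * topk * capacity_factor)
    # Round up to the next multiple of pad_multiple, the n + (-n % m) idiom used above.
    budget += -budget % pad_multiple
    return budget


if __name__ == "__main__":
    # 1024 tokens, top-2 routing, capacity factor 1.5, example alignment of 32:
    # 1024 * 2 * 1.5 = 3072, already a multiple of 32, so the padded budget stays 3072.
    assert expert_rank_token_budget(1024, 2, 1.5, 32) == 3072
    # A non-aligned case: 1000 * 2 * 1.5 = 3000 is rounded up to 3008.
    assert expert_rank_token_budget(1000, 2, 1.5, 32) == 3008

The dispatcher raises its over_budget flag when the padded per-rank token load exceeds this budget, which is the condition the per-layer assertions in the test above verify.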