68 commits
1f62ecc
Add --moe-use-device-initiated-grouped-gemm to allow token_per_expert…
QiZhangNV Nov 3, 2025
c7df801
Initial change for packed offloading
vasunvidia Nov 17, 2025
071f996
Bug fix
Nov 17, 2025
35cf0c4
Mem Opt
vasunvidia Nov 17, 2025
1926269
Handle MXFP8Tensor offload
Nov 20, 2025
8508f89
Enable Packed offloading to CPU pinned memory with PACKED_OFFLOAD_CPU=1
Nov 20, 2025
1f4c972
Enable activation truncation for first step
Nov 21, 2025
77c6d0f
Overflow check and assert
Nov 22, 2025
b3e4fca
Check in temporary solution for detecting overflow in receiving buffer
nanz-nv Nov 22, 2025
f536797
Reconstruct the stash buffer into a 2D structure
nanz-nv Nov 23, 2025
7775e04
Refactor the code to check overflow in HybridEP receiving buffer
nanz-nv Nov 24, 2025
e1542c2
Use CPU offloading context manager as a WAR for now to WAR the proble…
nanz-nv Nov 24, 2025
7534792
Add support for paged stashing
nanz-nv Nov 25, 2025
c053801
Add the feature of speculative CE stashing
nanz-nv Nov 26, 2025
beec6f5
Fix PP schedule
Nov 26, 2025
d6bf23c
Use common buffer across VP for paged stashing
vasunvidia Nov 26, 2025
f364bed
Disable Packed Offloading for validation
Nov 27, 2025
021394c
Fix perf issue in packed stash/pop kernels
nanz-nv Nov 27, 2025
59b1ae5
Minor fix for tensor allocation and padding requirement on budget
nanz-nv Dec 7, 2025
0fc192b
Packed/paged offloading is currently not stream-safe. Need to put stash…
nanz-nv Dec 7, 2025
14212f5
add new hybrid ep
Autumn1998 Dec 9, 2025
b26b52d
Remove the overflow check in framework because it is now done by hybr…
nanz-nv Dec 10, 2025
0fd67d2
Fix one merge conflict
nanz-nv Dec 10, 2025
1fa9bcc
Code cleanup
vasunvidia Dec 11, 2025
249a0e0
Add second autograd to avoid triple buffering
vasunvidia Dec 12, 2025
1551704
Avoid unnecessary wait_stream for reload in case of 1f1b
vasunvidia Dec 12, 2025
8528208
Check in dynamic-shape-aware SwiGLU triton kernel
nanz-nv Dec 18, 2025
8eed19f
Major cleanup and refactor
nanz-nv Dec 18, 2025
5697300
Check in paged_stash.py that was omitted in the previous commit
nanz-nv Dec 18, 2025
c7481cc
Remove d2d page feature for now
nanz-nv Dec 18, 2025
f3714b2
Update added arguments and add compatibility check
nanz-nv Dec 18, 2025
603af47
refine overflow check
nanz-nv Dec 18, 2025
519a609
Fixing lint issues
nanz-nv Dec 19, 2025
3dfdb75
Minor refactor
vasunvidia Jan 8, 2026
89d2ddd
Add unit test for Paged Stashing
vasunvidia Jan 9, 2026
675a175
1. allocate stashing buffer based on avg token count if STASH_BUFFER_…
nanz-nv Jan 22, 2026
4aa7eab
Reenable overlapping of stashing kernels
nanz-nv Jan 23, 2026
69ca640
Remove a buggy/redundant reset
nanz-nv Feb 3, 2026
1d59600
Cleanup moe-expert-rank-capacity-factor argument.
vasunvidia Feb 9, 2026
c4063a5
Update moe_use_device_initiated_grouped_gemm check for paged stashing…
vasunvidia Feb 21, 2026
8cd03d8
Remove the WAR of running warmup on a side stream
nanz-nv Mar 17, 2026
3cf56c3
Fix for data_iterator type check in Paged Stashing fallback
vasunvidia Mar 18, 2026
bde2bd7
Change to support eager-mode fallback for validation
vasunvidia Mar 18, 2026
1cb8f3c
Revert "Check in dynamic-shape-aware SwiGLU triton kernel"
nanz-nv Mar 18, 2026
afd0280
Fixed some minor issues
nanz-nv Mar 18, 2026
414adda
Fix the unit test
nanz-nv Mar 18, 2026
7a74dc2
Initial commit for spill to cpu feature
nanz-nv Mar 14, 2026
6505ba5
Move paged stashing knobs from env vars to transformer_config knobs
nanz-nv Mar 18, 2026
446e304
Refactor the knobs a bit so they are more intuitive
nanz-nv Mar 18, 2026
28c85d7
Use get_attr_wrapped_model util to access moe and mtp layers
vasunvidia Mar 18, 2026
ce58cb9
Refactor the unit test for paged stashing
nanz-nv Mar 20, 2026
2ce025b
Clean up after rebase
nanz-nv Mar 21, 2026
e826a84
Refactor/clean-up logging
nanz-nv Mar 25, 2026
7af01fc
Resolve review feedback
nanz-nv Mar 25, 2026
683dbd1
Fix fallback data read for PP=1
vasunvidia Mar 25, 2026
f7c2755
Paged stashing refactor
vasunvidia Mar 26, 2026
09579aa
Remove logical_shape check
vasunvidia Mar 26, 2026
f4f5410
Remove paged_stash_set_last_layer
vasunvidia Mar 26, 2026
4f8330d
Cleanup PadUnpadFunction
vasunvidia Mar 26, 2026
7b63934
Remove stash modules and remove stashing code for non-fused grouped gemm
nanz-nv Mar 30, 2026
c8d2c5f
Remove dead code
nanz-nv Mar 30, 2026
9b5796c
Fix TE import problem in experts.py
nanz-nv Mar 31, 2026
e7e8349
Fixed merge conflict
nanz-nv Mar 31, 2026
43a68ac
Address reviewer's comments
nanz-nv Mar 31, 2026
9f3fe06
Review comments
vasunvidia Apr 2, 2026
3c1d374
Add PagedStashRunner for overflow detection for pure M-LM training
vasunvidia Apr 2, 2026
547817e
Release stashing buffer before fallback to restore the memory
nanz-nv Apr 3, 2026
d0cff43
Fix an issue with PagedStashRunner
nanz-nv Apr 7, 2026
19 changes: 17 additions & 2 deletions megatron/core/full_cuda_graph.py
@@ -4,6 +4,7 @@

import logging

import gc
import torch

from megatron.core.tensor_parallel.random import get_all_rng_states
@@ -180,12 +181,10 @@ def __call__(self, *args, **kwargs):
torch.cuda.synchronize()
torch.distributed.barrier()
logger.info(f'CUDA graph capture done for {training_str}!!!')

if FullCudaGraphWrapper.cuda_graph[training_str] is None:
FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs)
else:
FullCudaGraphWrapper.cuda_graph[training_str].replay()

self.next_iter(training_str)
return FullCudaGraphWrapper.result[training_str]

@@ -196,3 +195,19 @@ def curr_iter(self, stage):
def next_iter(self, stage):
"""Increment current training/validation iteration."""
FullCudaGraphWrapper.curr_iteration[stage] += 1

def reset_cuda_graph(self, stage=None):
"""Reset CUDA graph."""
if stage is None or stage == 'training':
if FullCudaGraphWrapper.cuda_graph['training'] is not None:
del FullCudaGraphWrapper.cuda_graph['training']
FullCudaGraphWrapper.cuda_graph['training'] = None
FullCudaGraphWrapper.result['training'] = None
FullCudaGraphWrapper.curr_iteration['training'] = 0
if stage is None or stage == 'validation':
if FullCudaGraphWrapper.cuda_graph['validation'] is not None:
del FullCudaGraphWrapper.cuda_graph['validation']
FullCudaGraphWrapper.cuda_graph['validation'] = None
FullCudaGraphWrapper.result['validation'] = None
FullCudaGraphWrapper.curr_iteration['validation'] = 0
gc.collect()
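The new `reset_cuda_graph` hook lets a caller invalidate captured graphs when the executed workload changes, e.g. when paged stashing overflows and the run falls back to an eager path. A minimal usage sketch — `wrapped_fb` (a `FullCudaGraphWrapper` instance) and `overflow_detected` are hypothetical names for illustration, not part of this PR:

```python
# Hypothetical sketch: discard stale graphs so the next call re-captures them.
# `wrapped_fb` and `overflow_detected` are assumptions, not names from this PR.
if overflow_detected:
    wrapped_fb.reset_cuda_graph()                  # drop training and validation graphs
    # or reset a single stage only:
    wrapped_fb.reset_cuda_graph(stage='training')
# On the next __call__, cuda_graph[stage] is None, so capture runs again.
```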
12 changes: 12 additions & 0 deletions megatron/core/models/gpt/gpt_model.py
@@ -26,6 +26,7 @@
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.transformer.enums import CudaGraphScope, ModelType
from megatron.core.transformer.linear_cross_entropy import LinearCrossEntropyModule
from megatron.core.transformer.moe.paged_stash import paged_stash_init_chunk_handler
from megatron.core.transformer.multi_token_prediction import (
MultiTokenPredictionBlock,
mtp_on_this_rank,
@@ -483,6 +484,12 @@ def preprocess_for_fine_grained_offloading(self):
off_interface.mark_not_offload(param)
self.disable_param_offloading = False

def preprocess_for_paged_stash(self):
"""Preprocess for paged stash."""
return paged_stash_init_chunk_handler(
vp_size=self.config.virtual_pipeline_model_parallel_size, vp_stage=self.vp_stage
)

def forward(
self,
input_ids: Tensor,
@@ -519,6 +526,9 @@
if self.config.fine_grained_activation_offloading:
self.preprocess_for_fine_grained_offloading()

if self.config.moe_paged_stash:
self.preprocess_for_paged_stash()

inference_context = deprecate_inference_params(inference_context, inference_params)

preproc_output = self._preprocess(
@@ -823,6 +833,8 @@ def build_schedule_plan(

if self.config.fine_grained_activation_offloading:
self.preprocess_for_fine_grained_offloading()
if self.config.moe_paged_stash:
self.preprocess_for_paged_stash()

from ..common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan

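For reference, `preprocess_for_paged_stash` only forwards the model chunk's virtual-pipeline coordinates to the stash module. A standalone sketch of the same call, assuming a `config` and `model_chunk` shaped like the objects used above:

```python
from megatron.core.transformer.moe.paged_stash import paged_stash_init_chunk_handler

# Mirror of GPTModel.preprocess_for_paged_stash: register the chunk handler
# for this virtual-pipeline stage. vp_size is None when VP is disabled.
handler = paged_stash_init_chunk_handler(
    vp_size=config.virtual_pipeline_model_parallel_size,
    vp_stage=model_chunk.vp_stage,
)
```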
10 changes: 10 additions & 0 deletions megatron/core/pipeline_parallel/schedules.py
@@ -13,6 +13,7 @@
FineGrainedActivationOffloadingInterface as off_interface,
)
from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator
from megatron.core.transformer.moe.paged_stash import paged_stash_reset
from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator
from megatron.core.pipeline_parallel.utils import (
is_pp_first_stage,
@@ -637,6 +638,9 @@ def forward_backward_no_pipelining(
if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

if config.moe_paged_stash:
paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config)

no_sync_func = config.no_sync_func
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
@@ -1063,6 +1067,9 @@
adjust_tensor_shapes_fn is None
), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism"

if config.moe_paged_stash:
paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config)

if config.overlap_p2p_comm and config.batch_p2p_comm:
raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")

@@ -2265,6 +2272,9 @@
if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

if config.moe_paged_stash:
paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config)

# Disable async grad reductions
no_sync_func = config.no_sync_func
if no_sync_func is None:
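All three schedules now issue the same reset before launching microbatches. A sketch of the shared pattern as it appears above; `forward_only` is the schedule argument that distinguishes validation from training:

```python
from megatron.core.transformer.moe.paged_stash import paged_stash_reset

# Shared pattern from the three schedules: reset per-iteration stash state,
# enabling stashing only for training steps (forward_only=False).
if config.moe_paged_stash:
    paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config)
```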
46 changes: 36 additions & 10 deletions megatron/core/transformer/moe/experts.py
@@ -5,6 +5,7 @@
from collections.abc import Callable
from copy import deepcopy
from dataclasses import dataclass
from contextlib import nullcontext
from functools import partial
from itertools import chain
from math import ceil
@@ -43,6 +44,11 @@
get_align_size_for_quantization,
skip_routed_expert_padding,
)
from megatron.core.transformer.moe.paged_stash import (
get_paged_stash_context,
paged_stash_group_commit,
paged_stash_group_start,
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import (
ensure_metadata_has_dp_cp_group,
@@ -53,6 +59,7 @@

if HAVE_TE:
from megatron.core.extensions.transformer_engine import Fp8Padding, Fp8Unpadding
import transformer_engine as te
else:
Fp8Padding, Fp8Unpadding = None, None

@@ -942,19 +949,38 @@ def _fused_forward(
tokens_per_expert = torch.tensor(
tokens_per_expert, dtype=torch.int, device=permuted_probs.device
)

# Call fused impl
output = ops(
permuted_local_hidden_states,
tokens_per_expert, # FC1
permuted_probs, # Scaled SwiGLU
tokens_per_expert, # FC2
)

# if the number of tokens is 0, pad the hidden states to 256

if self.config.moe_paged_stash:
permuted_local_hidden_states = paged_stash_group_start(permuted_local_hidden_states)
max_num_tokens = permuted_local_hidden_states.shape[0]
# Average/expected tokens is a pre-padding estimate used by paged stashing heuristics.
# moe_expert_rank_capacity_factor is required when moe_paged_stash is enabled.
cap_factor = self.config.moe_expert_rank_capacity_factor
avg_num_tokens = (
int(max_num_tokens // cap_factor) if cap_factor is not None and cap_factor > 0 else None
)
stash_context = get_paged_stash_context(
name="grouped_mlp",
max_num_tokens=max_num_tokens,
num_tokens_tensor=tokens_per_expert.sum(),
avg_num_tokens=avg_num_tokens,
)
else:
stash_context = nullcontext()
with stash_context:
# Call fused impl
output = ops(
permuted_local_hidden_states,
tokens_per_expert, # FC1
permuted_probs, # Scaled SwiGLU
tokens_per_expert, # FC2
)
# Remove padding if needed
if unpadded_tokens_per_expert is not None:
output = self.quantization_unpadding(output, unpadded_tokens_per_expert)

if self.config.moe_paged_stash:
output = paged_stash_group_commit(output, name="grouped_mlp")
return output

def bias_act_func(self, intermediate_parallel, bias_parallel, permuted_probs):
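Taken together, the stash calls added to `_fused_forward` follow a start/context/commit protocol around the fused grouped GEMM. A condensed sketch of that protocol; `hidden`, `tokens_per_expert`, `avg_num_tokens`, and `fused_grouped_mlp` (standing in for the `ops(...)` call above) are placeholders:

```python
from megatron.core.transformer.moe.paged_stash import (
    get_paged_stash_context,
    paged_stash_group_commit,
    paged_stash_group_start,
)

# Condensed sketch of the protocol used in _fused_forward above.
hidden = paged_stash_group_start(hidden)             # open a stash group for this layer
with get_paged_stash_context(
    name="grouped_mlp",
    max_num_tokens=hidden.shape[0],                  # padded upper bound on tokens
    num_tokens_tensor=tokens_per_expert.sum(),       # actual token count (device tensor)
    avg_num_tokens=avg_num_tokens,                   # pre-padding estimate, or None
):
    output = fused_grouped_mlp(hidden, tokens_per_expert)  # activations stashed here
output = paged_stash_group_commit(output, name="grouped_mlp")  # close the group
```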