Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
7da860a
Add flexible virtual pipeline parallel (fVPP)
duncanriach Feb 11, 2026
c07a8ce
Fix bug for PP=1
duncanriach Feb 13, 2026
1873e40
Fix bug related to MTP being on final PP stage
duncanriach Feb 13, 2026
7b30eca
Improve comment
duncanriach Feb 13, 2026
1800757
Fixed linting errors
duncanriach Feb 13, 2026
1317202
Fix isort error
duncanriach Feb 13, 2026
95e5391
Fix test
duncanriach Feb 13, 2026
80835b9
Fix tests
duncanriach Feb 13, 2026
be9f823
Fix process_groups test
duncanriach Feb 17, 2026
749de5c
Still able to use --num-layers with --hybrid-override-pattern
duncanriach Feb 17, 2026
94771ad
Add better deprecation support to MambaModel
duncanriach Feb 17, 2026
3e5d367
Make new style consistent with old
duncanriach Feb 17, 2026
026009e
Fix linting
duncanriach Feb 17, 2026
2642d00
Fix ModelOpt builder
duncanriach Feb 17, 2026
5faf149
More fixes for ModelOpt
duncanriach Feb 17, 2026
2c1feff
Re-enable hybrid pp4 test
duncanriach Feb 18, 2026
91a98ae
Added backwards compatibility support for attention and mlp ratios
duncanriach Feb 18, 2026
41c8c7b
Fix MambaModel parameter deprecation code
duncanriach Feb 18, 2026
0bc2037
Fix functional test loader to handle pipe symbol
duncanriach Feb 18, 2026
dab21de
Fix test
duncanriach Feb 18, 2026
2c2373e
Update golden values for hybrid pp=4 test
duncanriach Feb 19, 2026
19d7e1a
Improve functional test handling of special symbols in layer patterns
duncanriach Feb 19, 2026
92c1cba
Add hybrid model functional test for pp2vpp2
duncanriach Feb 20, 2026
baf881b
Add golden values for hybrid model pp2vpp2 functional test
duncanriach Feb 20, 2026
1ca7c76
Fix linting
duncanriach Feb 21, 2026
ca8558e
Fix loading of legacy hybrid model checkpoints
duncanriach Feb 21, 2026
0f5d3ef
Cleanup previous commit
duncanriach Feb 21, 2026
1f49a66
Improve comment
duncanriach Feb 21, 2026
415c973
Merge branch 'main' into add-hybrid-fvpp-to-main-v1
ko3n1g Feb 23, 2026
1401de1
Enable different inference pipelining of pre-trained model
duncanriach Feb 24, 2026
ca74544
Improve comment
duncanriach Feb 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions examples/multimodal/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,7 @@ def model_provider(
patch_dim=args.patch_dim,
language_rotary_base=args.rotary_base,
language_rope_scaling=args.use_rope_scaling,
hybrid_attention_ratio=args.hybrid_attention_ratio,
hybrid_mlp_ratio=args.hybrid_mlp_ratio,
hybrid_override_pattern=args.hybrid_override_pattern,
hybrid_layer_pattern=args.hybrid_layer_pattern,
fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
image_token_index=image_token_index,
pixel_shuffle=args.pixel_shuffle,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,15 @@ MODEL_ARGS=" \
\
--attention-backend flash \
--disable-gloo-process-groups \
--is-hybrid-model \
--mamba-num-heads 64 \
--mamba-head-dim 64 \
--hybrid-override-pattern MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME \
--hybrid-layer-pattern MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME \
--use-mcore-models \
--untie-embeddings-and-output-weights \
--disable-bias-linear \
--init-method-std 0.0173 \
--position-embedding-type none \
--squared-relu \
--num-layers 52 \
--hidden-size 2688 \
--num-attention-heads 32 \
--group-query-attention \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,13 @@ MODEL_ARGS=" \
--no-rope-fusion \
--normalization RMSNorm \
--squared-relu \
--num-layers 56 \
--hidden-size 4480 \
--ffn-hidden-size 15680 \
--num-attention-heads 40 \
--kv-channels 128 \
--group-query-attention \
--num-query-groups 8 \
--hybrid-override-pattern M-M-M-MM-M-M-M*-M-M-M*-M-M-M-M*-M-M-M-M*-M-MM-M-M-M-M-M- \
--is-hybrid-model \
--hybrid-layer-pattern M-M-M-MM-M-M-M*-M-M-M*-M-M-M-M*-M-M-M-M*-M-MM-M-M-M-M-M- \
--mamba-head-dim 80 \
--mamba-num-heads 128 \
--mamba-num-groups 8 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@ MODEL_ARGS=" \
--position-embedding-type none \
--normalization RMSNorm \
--squared-relu \
--num-layers 98 \
--hidden-size 8192 \
--ffn-hidden-size 30720 \
--num-attention-heads 64 \
--kv-channels 128 \
--group-query-attention \
--num-query-groups 8 \
--hybrid-override-pattern M-M-M-M-M-M-M-M-M*-M-M-M-M-M-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-M-M---MM---M-M*-M-M-M-M-M- \
--is-hybrid-model \
--hybrid-layer-pattern M-M-M-M-M-M-M-M-M*-M-M-M-M-M-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-M-M---MM---M-M*-M-M-M-M-M- \
--mamba-head-dim 64 \
--mamba-num-heads 256 \
--mamba-num-groups 8 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@ MODEL_ARGS=" \
--no-position-embedding \
--normalization RMSNorm \
--squared-relu \
--num-layers 52 \
--hidden-size 3072 \
--ffn-hidden-size 12288 \
--kv-channels 128 \
--num-attention-heads 32 \
--group-query-attention \
--num-query-groups 8 \
--hybrid-override-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \
--hybrid-layer-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \
--mamba-head-dim 64 \
--mamba-num-heads 112 \
--mamba-num-groups 8 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ MODEL_ARGS=" \
--save-interval 100000 \
--micro-batch-size 1 \
--attention-backend flash \
--is-hybrid-model \
--hybrid-override-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \
--hybrid-layer-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \
--mamba-state-dim 256 \
--tiktoken-pattern v2 \
--use-mcore-models \
Expand All @@ -22,7 +21,6 @@ MODEL_ARGS=" \
--init-method-std 0.0099 \
--position-embedding-type none \
--squared-relu \
--num-layers 118 \
--hidden-size 8192 \
--num-attention-heads 64 \
--group-query-attention \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,12 @@ MODEL_ARGS=" \
--no-position-embedding \
--normalization RMSNorm \
--squared-relu \
--num-layers 52 \
--hidden-size 4096 \
--ffn-hidden-size 21504 \
--num-attention-heads 32 \
--group-query-attention \
--num-query-groups 8 \
--hybrid-override-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \
--is-hybrid-model \
--hybrid-layer-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \
--mamba-head-dim 64 \
--mamba-num-heads 128 \
--mamba-num-groups 8 \
Expand Down
3 changes: 1 addition & 2 deletions examples/rl/model_configs/nemotron5_56b.sh
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ MODEL_OPTIONS="\
--first-last-layers-bf16 \
\
--fp8-recipe tensorwise \
--hybrid-override-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \
--hybrid-layer-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \
--spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \
--mamba-state-dim 256 \
--per-split-data-args-path ${BLEND_PATH} \
Expand All @@ -82,7 +82,6 @@ MODEL_OPTIONS="\
--init-method-std 0.0099 \
--position-embedding-type none \
--squared-relu \
--num-layers 118 \
--hidden-size 8192 \
--num-attention-heads 64 \
--group-query-attention \
Expand Down
3 changes: 1 addition & 2 deletions examples/rl/model_configs/nemotron5_8b.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ MODEL_OPTIONS="\
--inference-max-seq-length $MAX_SEQ_LENGTH \
--inference-max-requests $MAX_INFERENCE_BS \
--pretrained-checkpoint $CHECKPOINT \
--hybrid-override-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \
--hybrid-layer-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \
--spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \
--tiktoken-pattern v2 \
--distributed-timeout-minutes 60 \
Expand All @@ -73,7 +73,6 @@ MODEL_OPTIONS="\
--init-method-std 0.014 \
--position-embedding-type none \
--squared-relu \
--num-layers 52 \
--hidden-size 4096 \
--num-attention-heads 32 \
--group-query-attention \
Expand Down
4 changes: 1 addition & 3 deletions examples/rl/model_configs/nemotron5p5_12b_H.sh
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,8 @@ MODEL_OPTIONS="\
--num-layers-at-end-in-bf16 2 \
--fp8-param-gather \
--disable-gloo-process-groups \
--is-hybrid-model \
--mamba-head-dim 80 \
--hybrid-override-pattern M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M- \
--hybrid-layer-pattern M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M- \
--spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \
--tiktoken-pattern v2 \
--distributed-timeout-minutes 10 \
Expand All @@ -89,7 +88,6 @@ MODEL_OPTIONS="\
--init-method-std 0.0125 \
--position-embedding-type none \
--squared-relu \
--num-layers 62 \
--hidden-size 5120 \
--num-attention-heads 40 \
--group-query-attention \
Expand Down
4 changes: 1 addition & 3 deletions mamba_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@ def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None, p
mamba_stack_spec=mamba_stack_spec,
vocab_size=args.padded_vocab_size,
max_sequence_length=args.max_position_embeddings,
hybrid_layer_pattern=args.hybrid_layer_pattern,
pre_process=pre_process,
hybrid_attention_ratio=args.hybrid_attention_ratio,
hybrid_mlp_ratio=args.hybrid_mlp_ratio,
hybrid_override_pattern=args.hybrid_override_pattern,
post_process=post_process,
fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
parallel_output=True,
Expand Down
92 changes: 73 additions & 19 deletions megatron/core/models/mamba/mamba_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import warnings
from typing import Literal, Optional

from torch import Tensor
Expand Down Expand Up @@ -38,16 +39,25 @@ class MambaModel(LanguageModule):
vocab_size (int): Vocabulary size
max_sequence_length (int): maximum size of sequence.
This is used for positional embedding
pre_process (bool, optional): Include embedding layer
(used with pipeline parallelism). Defaults to True.
hybrid_attention_ratio (float, optional): The target ratio of attention
layers to total layers
hybrid_mlp_ratio (float, optional): The target ratio of mlp layers to total layers
hybrid_override_pattern (str, optional): Unified hybrid layer pattern with optional MTP.
hybrid_layer_pattern (str): Unified hybrid layer pattern with optional MTP and
pipeline stage boundaries.
Format: "<main_pattern>/<mtp_pattern>/<mtp_pattern>/..."
The main pattern may contain "|" to define pipeline stage boundaries.
Examples:
- "M*M*" -> main decoder only, no MTP
- "M*M*/MM/MM" -> main="M*M*", mtp="MM", 2 depths
- "M-M-|M-M*-|M-M-|M-M*-" -> 4 pipeline segments
hybrid_attention_ratio (float, optional): Deprecated. Use hybrid_layer_pattern instead.
If set to a value > 0.0 and hybrid_layer_pattern is None, a pattern will be
generated from the ratio with a deprecation warning.
hybrid_mlp_ratio (float, optional): Deprecated. Use hybrid_layer_pattern instead.
If set to a value > 0.0 and hybrid_layer_pattern is None, a pattern will be
generated from the ratio with a deprecation warning.
hybrid_override_pattern (str, optional): Deprecated. Use hybrid_layer_pattern instead.
If set and hybrid_layer_pattern is None, the value is copied to hybrid_layer_pattern
with a deprecation warning.
pre_process (bool, optional): Include embedding layer
(used with pipeline parallelism). Defaults to True.
post_process (bool, optional): Include an output layer (used with pipeline parallelism).
Defaults to True.
fp16_lm_cross_entropy (bool, optional): Defaults to False.
Expand All @@ -65,6 +75,7 @@ class MambaModel(LanguageModule):
interpolating RoPE for longer sequences. The value must be a float larger than 1.0.
Defaults to None.
pg_collection (ProcessGroupCollection, optional): Model communication process groups.
vp_stage (Optional[int], optional): Virtual pipeline stage index. Defaults to None.
"""

def __init__(
Expand All @@ -73,10 +84,11 @@ def __init__(
mamba_stack_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
hybrid_attention_ratio: float = 0.0,
hybrid_mlp_ratio: float = 0.0,
hybrid_layer_pattern: str = None,
hybrid_attention_ratio: float = None,
hybrid_mlp_ratio: float = None,
hybrid_override_pattern: str = None,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
Expand All @@ -98,29 +110,72 @@ def __init__(
self.mamba_stack_spec: ModuleSpec = mamba_stack_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.hybrid_layer_pattern = hybrid_layer_pattern
self.pre_process = pre_process
self.hybrid_attention_ratio = hybrid_attention_ratio
self.hybrid_mlp_ratio = hybrid_mlp_ratio
self.hybrid_override_pattern = hybrid_override_pattern
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
self.vp_stage = vp_stage

# Parse unified pattern to extract main and MTP components
from megatron.core.ssm.mamba_hybrid_layer_allocation import parse_hybrid_pattern
# Backward compatibility for deprecated hybrid parameters
if hybrid_override_pattern is not None:
if hybrid_layer_pattern is None:
warnings.warn(
"hybrid_override_pattern has been deprecated. "
"Use hybrid_layer_pattern instead.",
DeprecationWarning,
stacklevel=2,
)
self.hybrid_layer_pattern = hybrid_override_pattern
else:
raise ValueError(
"hybrid_override_pattern and hybrid_layer_pattern cannot both be set. "
"hybrid_override_pattern has been deprecated; use hybrid_layer_pattern instead."
)
if (hybrid_attention_ratio is not None and hybrid_attention_ratio > 0.0) or (
hybrid_mlp_ratio is not None and hybrid_mlp_ratio > 0.0
):
warnings.warn(
"hybrid_attention_ratio and hybrid_mlp_ratio have been deprecated. "
"Use hybrid_layer_pattern instead.",
DeprecationWarning,
stacklevel=2,
)
if self.hybrid_layer_pattern is None:
from megatron.core.ssm.mamba_hybrid_layer_allocation import pattern_from_ratios

attn_ratio = hybrid_attention_ratio if hybrid_attention_ratio else 0.0
mlp_ratio = hybrid_mlp_ratio if hybrid_mlp_ratio else 0.0
self.hybrid_layer_pattern = pattern_from_ratios(
config.num_layers, attn_ratio, mlp_ratio
)

# Parse unified pattern to extract main and MTP components, and
# determine the pipeline segment for this model instance.
from megatron.core.ssm.mamba_hybrid_layer_allocation import (
parse_hybrid_pattern,
select_pipeline_segment,
)

parsed = parse_hybrid_pattern(hybrid_override_pattern)
parsed = parse_hybrid_pattern(self.hybrid_layer_pattern)
self.mtp_pattern = parsed.mtp_pattern
self.mtp_num_depths = parsed.mtp_num_depths

layer_type_list, layer_offset = select_pipeline_segment(
parsed.main_pattern or '', self.pg_collection.pp, vp_stage
)

# Determine if MTP is needed (based on pattern parsing)
self.mtp_process = (
self.mtp_pattern is not None
and self.mtp_num_depths > 0
and mtp_on_this_rank(self.config, vp_stage=self.vp_stage)
# The following forces MTP to be on the final pipeline stage. It might be more optimal
# to split the hybrid layer pattern into pipeline stages before parsing the pattern for
# the current pipeline stage. This could also enable MTP standalone (MTP in a pipeline
# stage separate from loss) to be supported in the hybrid model.
and mtp_on_this_rank(self.config, ignore_virtual=False, vp_stage=self.vp_stage)
)

# megatron core pipelining currently depends on model type
Expand Down Expand Up @@ -151,9 +206,8 @@ def __init__(
mamba_stack_spec,
self.config,
pre_process=self.pre_process,
hybrid_attention_ratio=self.hybrid_attention_ratio,
hybrid_mlp_ratio=self.hybrid_mlp_ratio,
hybrid_override_pattern=parsed.main_pattern,
layer_type_list=layer_type_list,
pp_layer_offset=layer_offset,
post_process=self.post_process,
dtype=config.params_dtype,
pg_collection=self.pg_collection,
Expand Down
8 changes: 2 additions & 6 deletions megatron/core/models/multimodal/llava_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ def __init__(
language_rotary_base: int = 10000,
language_rope_scaling: bool = False,
language_rope_scaling_factor: float = 8.0,
hybrid_attention_ratio: float = 1.0,
hybrid_mlp_ratio: float = 1.0,
hybrid_override_pattern: str = None,
hybrid_layer_pattern: str = None,
fp16_lm_cross_entropy: bool = False,
image_token_index: int = DEFAULT_IMAGE_TOKEN_INDEX,
pixel_shuffle: bool = False,
Expand Down Expand Up @@ -206,9 +204,7 @@ def __init__(
parallel_output=parallel_output,
position_embedding_type=language_position_embedding_type,
pre_process=self.pre_process,
hybrid_attention_ratio=hybrid_attention_ratio,
hybrid_mlp_ratio=hybrid_mlp_ratio,
hybrid_override_pattern=hybrid_override_pattern,
hybrid_layer_pattern=hybrid_layer_pattern,
post_process=self.post_process,
rotary_percent=language_rotary_percent,
rotary_base=language_rotary_base,
Expand Down
Loading
Loading