diff --git a/examples/multimodal/model.py b/examples/multimodal/model.py index 212a2cda021..e98ff2df519 100644 --- a/examples/multimodal/model.py +++ b/examples/multimodal/model.py @@ -213,9 +213,7 @@ def model_provider( patch_dim=args.patch_dim, language_rotary_base=args.rotary_base, language_rope_scaling=args.use_rope_scaling, - hybrid_attention_ratio=args.hybrid_attention_ratio, - hybrid_mlp_ratio=args.hybrid_mlp_ratio, - hybrid_override_pattern=args.hybrid_override_pattern, + hybrid_layer_pattern=args.hybrid_layer_pattern, fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, image_token_index=image_token_index, pixel_shuffle=args.pixel_shuffle, diff --git a/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.sh b/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.sh index c294e03235c..1fa00889e99 100644 --- a/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.sh +++ b/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.sh @@ -30,17 +30,15 @@ MODEL_ARGS=" \ \ --attention-backend flash \ --disable-gloo-process-groups \ - --is-hybrid-model \ --mamba-num-heads 64 \ --mamba-head-dim 64 \ - --hybrid-override-pattern MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME \ + --hybrid-layer-pattern MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME \ --use-mcore-models \ --untie-embeddings-and-output-weights \ --disable-bias-linear \ --init-method-std 0.0173 \ --position-embedding-type none \ --squared-relu \ - --num-layers 52 \ --hidden-size 2688 \ --num-attention-heads 32 \ --group-query-attention \ diff --git a/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-Nano-9B-v2.sh b/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-Nano-9B-v2.sh index a2212483008..83867430a97 100644 --- a/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-Nano-9B-v2.sh +++ b/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-Nano-9B-v2.sh @@ -19,15 +19,13 @@ MODEL_ARGS=" \ --no-rope-fusion \ --normalization RMSNorm \ --squared-relu \ - --num-layers 56 \ --hidden-size 4480 \ --ffn-hidden-size 15680 \ --num-attention-heads 40 \ --kv-channels 128 \ --group-query-attention \ --num-query-groups 8 \ - --hybrid-override-pattern M-M-M-MM-M-M-M*-M-M-M*-M-M-M-M*-M-M-M-M*-M-MM-M-M-M-M-M- \ - --is-hybrid-model \ + --hybrid-layer-pattern M-M-M-MM-M-M-M*-M-M-M*-M-M-M-M*-M-M-M-M*-M-MM-M-M-M-M-M- \ --mamba-head-dim 80 \ --mamba-num-heads 128 \ --mamba-num-groups 8 \ diff --git a/examples/post_training/modelopt/conf/nvidia/Nemotron-H-47B-Reasoning-128K.sh b/examples/post_training/modelopt/conf/nvidia/Nemotron-H-47B-Reasoning-128K.sh index ad07c1061c5..901e607f298 100644 --- a/examples/post_training/modelopt/conf/nvidia/Nemotron-H-47B-Reasoning-128K.sh +++ b/examples/post_training/modelopt/conf/nvidia/Nemotron-H-47B-Reasoning-128K.sh @@ -18,15 +18,13 @@ MODEL_ARGS=" \ --position-embedding-type none \ --normalization RMSNorm \ --squared-relu \ - --num-layers 98 \ --hidden-size 8192 \ --ffn-hidden-size 30720 \ --num-attention-heads 64 \ --kv-channels 128 \ --group-query-attention \ --num-query-groups 8 \ - --hybrid-override-pattern M-M-M-M-M-M-M-M-M*-M-M-M-M-M-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-M-M---MM---M-M*-M-M-M-M-M- \ - --is-hybrid-model \ + --hybrid-layer-pattern M-M-M-M-M-M-M-M-M*-M-M-M-M-M-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-M-M---MM---M-M*-M-M-M-M-M- \ --mamba-head-dim 64 \ --mamba-num-heads 256 \ --mamba-num-groups 8 \ diff --git a/examples/post_training/modelopt/conf/nvidia/Nemotron-H-4B-Instruct.sh b/examples/post_training/modelopt/conf/nvidia/Nemotron-H-4B-Instruct.sh index 4ba91dbd8c6..084db49e0eb 100644 --- a/examples/post_training/modelopt/conf/nvidia/Nemotron-H-4B-Instruct.sh +++ b/examples/post_training/modelopt/conf/nvidia/Nemotron-H-4B-Instruct.sh @@ -21,14 +21,13 @@ MODEL_ARGS=" \ --no-position-embedding \ --normalization RMSNorm \ --squared-relu \ - --num-layers 52 \ --hidden-size 3072 \ --ffn-hidden-size 12288 \ --kv-channels 128 \ --num-attention-heads 32 \ --group-query-attention \ --num-query-groups 8 \ - --hybrid-override-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \ + --hybrid-layer-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \ --mamba-head-dim 64 \ --mamba-num-heads 112 \ --mamba-num-groups 8 \ diff --git a/examples/post_training/modelopt/conf/nvidia/Nemotron-H-56B-Base-8K.sh b/examples/post_training/modelopt/conf/nvidia/Nemotron-H-56B-Base-8K.sh index 8377f0f11d6..645a159d075 100644 --- a/examples/post_training/modelopt/conf/nvidia/Nemotron-H-56B-Base-8K.sh +++ b/examples/post_training/modelopt/conf/nvidia/Nemotron-H-56B-Base-8K.sh @@ -12,8 +12,7 @@ MODEL_ARGS=" \ --save-interval 100000 \ --micro-batch-size 1 \ --attention-backend flash \ - --is-hybrid-model \ - --hybrid-override-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \ + --hybrid-layer-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \ --mamba-state-dim 256 \ --tiktoken-pattern v2 \ --use-mcore-models \ @@ -22,7 +21,6 @@ MODEL_ARGS=" \ --init-method-std 0.0099 \ --position-embedding-type none \ --squared-relu \ - --num-layers 118 \ --hidden-size 8192 \ --num-attention-heads 64 \ --group-query-attention \ diff --git a/examples/post_training/modelopt/conf/nvidia/Nemotron-H-8B-Base-8K.sh b/examples/post_training/modelopt/conf/nvidia/Nemotron-H-8B-Base-8K.sh index b04bf76f360..66f3ad368b4 100644 --- a/examples/post_training/modelopt/conf/nvidia/Nemotron-H-8B-Base-8K.sh +++ b/examples/post_training/modelopt/conf/nvidia/Nemotron-H-8B-Base-8K.sh @@ -20,14 +20,12 @@ MODEL_ARGS=" \ --no-position-embedding \ --normalization RMSNorm \ --squared-relu \ - --num-layers 52 \ --hidden-size 4096 \ --ffn-hidden-size 21504 \ --num-attention-heads 32 \ --group-query-attention \ --num-query-groups 8 \ - --hybrid-override-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \ - --is-hybrid-model \ + --hybrid-layer-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \ --mamba-head-dim 64 \ --mamba-num-heads 128 \ --mamba-num-groups 8 \ diff --git a/examples/rl/model_configs/nemotron5_56b.sh b/examples/rl/model_configs/nemotron5_56b.sh index 741cd054b73..23b9f99a72a 100644 --- a/examples/rl/model_configs/nemotron5_56b.sh +++ b/examples/rl/model_configs/nemotron5_56b.sh @@ -68,7 +68,7 @@ MODEL_OPTIONS="\ --first-last-layers-bf16 \ \ --fp8-recipe tensorwise \ - --hybrid-override-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \ + --hybrid-layer-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \ --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ --mamba-state-dim 256 \ --per-split-data-args-path ${BLEND_PATH} \ @@ -82,7 +82,6 @@ MODEL_OPTIONS="\ --init-method-std 0.0099 \ --position-embedding-type none \ --squared-relu \ - --num-layers 118 \ --hidden-size 8192 \ --num-attention-heads 64 \ --group-query-attention \ diff --git a/examples/rl/model_configs/nemotron5_8b.sh b/examples/rl/model_configs/nemotron5_8b.sh index 753d4e493a2..c18149f03d6 100644 --- a/examples/rl/model_configs/nemotron5_8b.sh +++ b/examples/rl/model_configs/nemotron5_8b.sh @@ -60,7 +60,7 @@ MODEL_OPTIONS="\ --inference-max-seq-length $MAX_SEQ_LENGTH \ --inference-max-requests $MAX_INFERENCE_BS \ --pretrained-checkpoint $CHECKPOINT \ - --hybrid-override-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \ + --hybrid-layer-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \ --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ --tiktoken-pattern v2 \ --distributed-timeout-minutes 60 \ @@ -73,7 +73,6 @@ MODEL_OPTIONS="\ --init-method-std 0.014 \ --position-embedding-type none \ --squared-relu \ - --num-layers 52 \ --hidden-size 4096 \ --num-attention-heads 32 \ --group-query-attention \ diff --git a/examples/rl/model_configs/nemotron5p5_12b_H.sh b/examples/rl/model_configs/nemotron5p5_12b_H.sh index adbcc8d03f0..1826d57e913 100644 --- a/examples/rl/model_configs/nemotron5p5_12b_H.sh +++ b/examples/rl/model_configs/nemotron5p5_12b_H.sh @@ -74,9 +74,8 @@ MODEL_OPTIONS="\ --num-layers-at-end-in-bf16 2 \ --fp8-param-gather \ --disable-gloo-process-groups \ - --is-hybrid-model \ --mamba-head-dim 80 \ - --hybrid-override-pattern M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M- \ + --hybrid-layer-pattern M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M- \ --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ --tiktoken-pattern v2 \ --distributed-timeout-minutes 10 \ @@ -89,7 +88,6 @@ MODEL_OPTIONS="\ --init-method-std 0.0125 \ --position-embedding-type none \ --squared-relu \ - --num-layers 62 \ --hidden-size 5120 \ --num-attention-heads 40 \ --group-query-attention \ diff --git a/mamba_builders.py b/mamba_builders.py index 5d31af60475..650ea4a719f 100644 --- a/mamba_builders.py +++ b/mamba_builders.py @@ -30,10 +30,8 @@ def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None, p mamba_stack_spec=mamba_stack_spec, vocab_size=args.padded_vocab_size, max_sequence_length=args.max_position_embeddings, + hybrid_layer_pattern=args.hybrid_layer_pattern, pre_process=pre_process, - hybrid_attention_ratio=args.hybrid_attention_ratio, - hybrid_mlp_ratio=args.hybrid_mlp_ratio, - hybrid_override_pattern=args.hybrid_override_pattern, post_process=post_process, fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, parallel_output=True, diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py index 8dd614fdaaa..f8987f68500 100644 --- a/megatron/core/models/mamba/mamba_model.py +++ b/megatron/core/models/mamba/mamba_model.py @@ -1,5 +1,6 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +import logging from typing import Literal, Optional from torch import Tensor @@ -26,8 +27,11 @@ WrappedTensor, deprecate_inference_params, is_using_quantization_scales, + log_single_rank, ) +logger = logging.getLogger(__name__) + class MambaModel(LanguageModule): """Mamba language model. @@ -38,16 +42,25 @@ class MambaModel(LanguageModule): vocab_size (int): Vocabulary size max_sequence_length (int): maximum size of sequence. This is used for positional embedding - pre_process (bool, optional): Include embedding layer - (used with pipeline parallelism). Defaults to True. - hybrid_attention_ratio (float, optional): The target ratio of attention - layers to total layers - hybrid_mlp_ratio (float, optional): The target ratio of mlp layers to total layers - hybrid_override_pattern (str, optional): Unified hybrid layer pattern with optional MTP. + hybrid_layer_pattern (str): Unified hybrid layer pattern with optional MTP and + pipeline stage boundaries. Format: "///..." + The main pattern may contain "|" to define pipeline stage boundaries. Examples: - "M*M*" -> main decoder only, no MTP - "M*M*/MM/MM" -> main="M*M*", mtp="MM", 2 depths + - "M-M-|M-M*-|M-M-|M-M*-" -> 4 pipeline segments + hybrid_attention_ratio (float, optional): Deprecated. Use hybrid_layer_pattern instead. + If set to a value > 0.0 and hybrid_layer_pattern is None, a pattern will be + generated from the ratio with a deprecation warning. + hybrid_mlp_ratio (float, optional): Deprecated. Use hybrid_layer_pattern instead. + If set to a value > 0.0 and hybrid_layer_pattern is None, a pattern will be + generated from the ratio with a deprecation warning. + hybrid_override_pattern (str, optional): Deprecated. Use hybrid_layer_pattern instead. + If set and hybrid_layer_pattern is None, the value is copied to hybrid_layer_pattern + with a deprecation warning. + pre_process (bool, optional): Include embedding layer + (used with pipeline parallelism). Defaults to True. post_process (bool, optional): Include an output layer (used with pipeline parallelism). Defaults to True. fp16_lm_cross_entropy (bool, optional): Defaults to False. @@ -65,6 +78,7 @@ class MambaModel(LanguageModule): interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None. pg_collection (ProcessGroupCollection, optional): Model communication process groups. + vp_stage (Optional[int], optional): Virtual pipeline stage index. Defaults to None. """ def __init__( @@ -73,10 +87,11 @@ def __init__( mamba_stack_spec: ModuleSpec, vocab_size: int, max_sequence_length: int, + hybrid_layer_pattern: Optional[str] = None, + hybrid_attention_ratio: Optional[float] = None, + hybrid_mlp_ratio: Optional[float] = None, + hybrid_override_pattern: Optional[str] = None, pre_process: bool = True, - hybrid_attention_ratio: float = 0.0, - hybrid_mlp_ratio: float = 0.0, - hybrid_override_pattern: str = None, post_process: bool = True, fp16_lm_cross_entropy: bool = False, parallel_output: bool = True, @@ -98,10 +113,8 @@ def __init__( self.mamba_stack_spec: ModuleSpec = mamba_stack_spec self.vocab_size = vocab_size self.max_sequence_length = max_sequence_length + self.hybrid_layer_pattern = hybrid_layer_pattern self.pre_process = pre_process - self.hybrid_attention_ratio = hybrid_attention_ratio - self.hybrid_mlp_ratio = hybrid_mlp_ratio - self.hybrid_override_pattern = hybrid_override_pattern self.post_process = post_process self.fp16_lm_cross_entropy = fp16_lm_cross_entropy self.parallel_output = parallel_output @@ -109,18 +122,63 @@ def __init__( self.position_embedding_type = position_embedding_type self.vp_stage = vp_stage - # Parse unified pattern to extract main and MTP components - from megatron.core.ssm.mamba_hybrid_layer_allocation import parse_hybrid_pattern + # Backward compatibility for deprecated hybrid parameters + if hybrid_override_pattern is not None: + if self.hybrid_layer_pattern is None: + log_single_rank( + logger, + logging.WARNING, + "hybrid_override_pattern has been deprecated. " + "Use hybrid_layer_pattern instead.", + ) + self.hybrid_layer_pattern = hybrid_override_pattern + else: + raise ValueError( + "hybrid_override_pattern and hybrid_layer_pattern cannot both be set. " + "hybrid_override_pattern has been deprecated; use hybrid_layer_pattern instead." + ) + if (hybrid_attention_ratio is not None and hybrid_attention_ratio > 0.0) or ( + hybrid_mlp_ratio is not None and hybrid_mlp_ratio > 0.0 + ): + log_single_rank( + logger, + logging.WARNING, + "hybrid_attention_ratio and hybrid_mlp_ratio have been deprecated. " + "Use hybrid_layer_pattern instead.", + ) + if self.hybrid_layer_pattern is None: + from megatron.core.ssm.mamba_hybrid_layer_allocation import pattern_from_ratios + + attn_ratio = hybrid_attention_ratio if hybrid_attention_ratio else 0.0 + mlp_ratio = hybrid_mlp_ratio if hybrid_mlp_ratio else 0.0 + self.hybrid_layer_pattern = pattern_from_ratios( + config.num_layers, attn_ratio, mlp_ratio + ) + + # Parse unified pattern to extract main and MTP components, and + # determine the pipeline segment for this model instance. + from megatron.core.ssm.mamba_hybrid_layer_allocation import ( + parse_hybrid_pattern, + select_pipeline_segment, + ) - parsed = parse_hybrid_pattern(hybrid_override_pattern) + parsed = parse_hybrid_pattern(self.hybrid_layer_pattern) self.mtp_pattern = parsed.mtp_pattern self.mtp_num_depths = parsed.mtp_num_depths + layer_type_list, layer_offset = select_pipeline_segment( + parsed.main_pattern or '', self.pg_collection.pp, vp_stage + ) + # Determine if MTP is needed (based on pattern parsing) self.mtp_process = ( self.mtp_pattern is not None and self.mtp_num_depths > 0 - and mtp_on_this_rank(self.config, vp_stage=self.vp_stage) + # The following forces MTP to be on the final pipeline stage. It might be more optimal + # to split the hybrid layer pattern into pipeline stages before parsing the pattern for + # the current pipeline stage. This could also enable MTP standalone (MTP in a pipeline + # stage separate from loss) to be supported in the hybrid model. + and mtp_on_this_rank(self.config, ignore_virtual=False, vp_stage=self.vp_stage) ) # megatron core pipelining currently depends on model type @@ -151,9 +209,8 @@ def __init__( mamba_stack_spec, self.config, pre_process=self.pre_process, - hybrid_attention_ratio=self.hybrid_attention_ratio, - hybrid_mlp_ratio=self.hybrid_mlp_ratio, - hybrid_override_pattern=parsed.main_pattern, + layer_type_list=layer_type_list, + pp_layer_offset=layer_offset, post_process=self.post_process, dtype=config.params_dtype, pg_collection=self.pg_collection, diff --git a/megatron/core/models/multimodal/llava_model.py b/megatron/core/models/multimodal/llava_model.py index d3e5d5e26f8..ecc0aaf5f9d 100644 --- a/megatron/core/models/multimodal/llava_model.py +++ b/megatron/core/models/multimodal/llava_model.py @@ -115,9 +115,7 @@ def __init__( language_rotary_base: int = 10000, language_rope_scaling: bool = False, language_rope_scaling_factor: float = 8.0, - hybrid_attention_ratio: float = 1.0, - hybrid_mlp_ratio: float = 1.0, - hybrid_override_pattern: str = None, + hybrid_layer_pattern: str = None, fp16_lm_cross_entropy: bool = False, image_token_index: int = DEFAULT_IMAGE_TOKEN_INDEX, pixel_shuffle: bool = False, @@ -206,9 +204,7 @@ def __init__( parallel_output=parallel_output, position_embedding_type=language_position_embedding_type, pre_process=self.pre_process, - hybrid_attention_ratio=hybrid_attention_ratio, - hybrid_mlp_ratio=hybrid_mlp_ratio, - hybrid_override_pattern=hybrid_override_pattern, + hybrid_layer_pattern=hybrid_layer_pattern, post_process=self.post_process, rotary_percent=language_rotary_percent, rotary_base=language_rotary_base, diff --git a/megatron/core/ssm/mamba_block.py b/megatron/core/ssm/mamba_block.py index ef67983d4cf..9ae40100678 100644 --- a/megatron/core/ssm/mamba_block.py +++ b/megatron/core/ssm/mamba_block.py @@ -22,7 +22,6 @@ from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols as LayerSymbols -from megatron.core.ssm.mamba_hybrid_layer_allocation import allocate_layers from megatron.core.transformer import TransformerConfig from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.identity_op import IdentityOp @@ -57,12 +56,11 @@ class MambaStack(GraphableMegatronModule, MegatronModule): in fp32. Defaults to False. pre_process (bool, optional): whether to include an embedding layer. Defaults to True. - hybrid_attention_ratio (float, optional): the target ratio of attention layers to - total layers. Defaults to 0.0. - hybrid_mlp_ratio (float, optional): the target ratio of mlp layers to total - layers. Defaults to 0.0. - hybrid_override_pattern (str, optional): the hybrid layer pattern to override - with. Defaults to None. + layer_type_list (list, optional): pre-computed list of layer type symbols for + this pipeline segment. When provided (by MambaModel), pipeline stage + selection has already been done via '|' separators in the pattern. + pp_layer_offset (int, optional): the global layer offset for this pipeline + segment. Defaults to 0. post_layer_norm (bool, optional): whether to include a final layer norm. Defaults to True. post_process (bool, optional): whether to include an output layer. @@ -71,6 +69,7 @@ class MambaStack(GraphableMegatronModule, MegatronModule): dtype (optional): the data type to use. Defaults to None. pg_collection (ProcessGroupCollection): the required model communication process groups to use. + is_mtp_layer (bool, optional): whether this is an MTP layer. Defaults to False. """ def __init__( @@ -79,9 +78,8 @@ def __init__( submodules: MambaStackSubmodules, residual_in_fp32=False, pre_process: bool = True, - hybrid_attention_ratio: float = 0.0, - hybrid_mlp_ratio: float = 0.0, - hybrid_override_pattern: str = None, + layer_type_list: Optional[list[str]] = None, + pp_layer_offset: int = 0, post_layer_norm: bool = True, post_process: bool = True, device=None, @@ -103,38 +101,18 @@ def __init__( # Required for pipeline parallel schedules self.input_tensor = None - - self.hybrid_attention_ratio = hybrid_attention_ratio - self.hybrid_mlp_ratio = hybrid_mlp_ratio - self.hybrid_override_pattern = hybrid_override_pattern self.pg_collection = pg_collection - # For MTP layers, always use pattern length (config.num_layers is for main decoder) - if self.is_mtp_layer: - num_layers_for_allocation = len(self.hybrid_override_pattern) - else: - num_layers_for_allocation = ( - self.config.num_layers - if self.config.num_layers is not None - else len(self.hybrid_override_pattern) - ) - - self.layer_type_list = allocate_layers( - num_layers_for_allocation, - self.hybrid_attention_ratio, - self.hybrid_mlp_ratio, - self.hybrid_override_pattern, - silent=self.is_mtp_layer, + assert layer_type_list is not None, ( + "layer_type_list must be provided. It should be pre-computed from " + "--hybrid-layer-pattern by MambaModel." ) + self.layer_type_list = layer_type_list - pp_layer_offset = 0 - if self.pp_group.size() > 1 and not self.is_mtp_layer: - pp_layer_offset, self.layer_type_list = self._select_layers_for_pipeline_parallel( - self.layer_type_list - ) - # Build main decoder layers using shared layer builder + # Build layers from the pre-selected segment self.layers = nn.ModuleList() for i, layer_type in enumerate(self.layer_type_list): + layer_number = i + 1 + pp_layer_offset if self.config.fp8: quant_init_context = get_fp8_context(self.config, i + pp_layer_offset, is_init=True) elif self.config.fp4: @@ -147,34 +125,35 @@ def __init__( submodules.mamba_layer, config=self.config, residual_in_fp32=residual_in_fp32, - layer_number=i + 1 + pp_layer_offset, + layer_number=layer_number, pp_layer_offset=pp_layer_offset, pg_collection=pg_collection, ) elif layer_type == LayerSymbols.ATTENTION: - # Transformer layers apply their own pp_layer_offset layer = build_module( submodules.attention_layer, config=self.config, - layer_number=i + 1, + layer_number=layer_number, pg_collection=pg_collection, is_mtp_layer=is_mtp_layer, + add_layer_offset=False, + pp_layer_offset=pp_layer_offset, ) elif layer_type == LayerSymbols.MLP: - # MLP layers apply their own pp_layer_offset layer = build_module( submodules.mlp_layer, config=self.config, - layer_number=i + 1, + layer_number=layer_number, pg_collection=pg_collection, + add_layer_offset=False, ) elif layer_type == LayerSymbols.MOE: - # MoE layers apply their own pp_layer_offset layer = build_module( submodules.moe_layer, config=self.config, - layer_number=i + 1, + layer_number=layer_number, pg_collection=pg_collection, + add_layer_offset=False, ) else: assert False, "unexpected layer_type" @@ -191,57 +170,6 @@ def __init__( eps=self.config.layernorm_epsilon, ) - def _select_layers_for_pipeline_parallel(self, layer_type_list): - assert self.config.virtual_pipeline_model_parallel_size is None, ( - "The Mamba hybrid model does not currently support " - "virtual/interleaved pipeline parallelism" - ) - - pp_rank = self.pp_group.rank() - pp_size = self.pp_group.size() - - num_layers_in_first = self.config.num_layers_in_first_pipeline_stage - num_layers_in_last = self.config.num_layers_in_last_pipeline_stage - - if num_layers_in_first is not None or num_layers_in_last is not None: - # Uneven pipeline parallelism: mirror the logic in - # get_transformer_layer_offset so that MambaStack and - # TransformerLayer agree on layer placement. - first = 0 if num_layers_in_first is None else num_layers_in_first - last = 0 if num_layers_in_last is None else num_layers_in_last - middle_num_layers = self.config.num_layers - first - last - - middle_pipeline_stages = pp_size - sum( - 1 for x in (num_layers_in_first, num_layers_in_last) if x is not None - ) - - if middle_pipeline_stages > 0: - layers_per_middle = middle_num_layers // middle_pipeline_stages - else: - layers_per_middle = 0 - - is_first_stage = num_layers_in_first is not None and pp_rank == 0 - is_last_stage = num_layers_in_last is not None and pp_rank == pp_size - 1 - - if is_first_stage: - offset = 0 - num_layers_this_rank = first - elif is_last_stage: - offset = self.config.num_layers - last - num_layers_this_rank = last - else: - middle_rank = pp_rank if num_layers_in_first is None else pp_rank - 1 - offset = middle_rank * layers_per_middle + first - num_layers_this_rank = layers_per_middle - else: - num_layers_per_pipeline_rank = self.config.num_layers // pp_size - offset = pp_rank * num_layers_per_pipeline_rank - num_layers_this_rank = num_layers_per_pipeline_rank - - selected_list = layer_type_list[offset : offset + num_layers_this_rank] - - return offset, selected_list - def set_input_tensor(self, input_tensor: Tensor): """Set input tensor to be used instead of forward()'s input. diff --git a/megatron/core/ssm/mamba_hybrid_layer_allocation.py b/megatron/core/ssm/mamba_hybrid_layer_allocation.py index d7002b2915d..de62257574c 100644 --- a/megatron/core/ssm/mamba_hybrid_layer_allocation.py +++ b/megatron/core/ssm/mamba_hybrid_layer_allocation.py @@ -4,34 +4,23 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Tuple -if __name__ != "__main__": - from megatron.core.utils import log_single_rank -else: - from typing import Any - - import torch - - def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any): - """Logs a message to the given rank.""" - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == rank: - logger.log(*args, **kwargs) - else: - logger.log(*args, **kwargs) +import torch +from megatron.core.utils import log_on_each_pipeline_stage logger = logging.getLogger(__name__) class Symbols: - """Symbols for different layer types.""" + """Symbols for different layer types and pattern separators.""" MAMBA = "M" ATTENTION = "*" MLP = "-" MOE = 'E' + PIPE = '|' MTP_SEPARATOR = "/" - VALID = {MAMBA, ATTENTION, MLP, MOE} + VALID_LAYERS = {MAMBA, ATTENTION, MLP, MOE} @dataclass @@ -39,7 +28,9 @@ class ParsedHybridPattern: """Result of parsing a unified hybrid pattern string. A unified pattern encodes both the main decoder pattern and the MTP pattern - in a single string using "/" as a separator. + in a single string using "/" as a separator. The main pattern may also + contain "|" pipe symbols to define pipeline stage boundaries for flexible + virtual pipeline parallelism (fVPP). Format: "///..." @@ -47,12 +38,15 @@ class ParsedHybridPattern: - "M*M*" -> main="M*M*", mtp=None, depths=0 (no MTP) - "M*M*/MM/MM" -> main="M*M*", mtp="MM", depths=2 - "MMMM/*M/*M/*M" -> main="MMMM", mtp="*M", depths=3 + - "M-M-|M-M*-/MM/MM" -> main="M-M-|M-M*-" (2 PP stages), mtp="MM", depths=2 The "/" symbol introduces MTP patterns. Each repeated pattern after the main decoder represents one MTP prediction depth. + The "|" symbol in the main pattern defines pipeline stage boundaries. + Attributes: - main_pattern: The main decoder layer pattern (e.g., "M*M*") + main_pattern: The main decoder layer pattern (e.g., "M*M*" or "M-M-|M-M*-") mtp_pattern: The MTP layer pattern per depth (e.g., "MM"), or None if no MTP mtp_num_depths: Number of MTP prediction depths (0 if no MTP) """ @@ -62,12 +56,139 @@ class ParsedHybridPattern: mtp_num_depths: int +def pattern_from_ratios( + num_layers: int, attention_ratio: float = 0.0, mlp_ratio: float = 0.0 +) -> str: + """Convert deprecated ratio arguments to a layer pattern string. + + Generates an evenly-spaced hybrid layer pattern from target attention and MLP + ratios. This exists for backward compatibility with code that uses the deprecated + hybrid_attention_ratio and hybrid_mlp_ratio parameters. + + Args: + num_layers: Total number of layers. + attention_ratio: Target ratio of attention layers to total layers. + mlp_ratio: Target ratio of MLP layers to total layers. + + Returns: + A layer pattern string (e.g., "MMM*MMM*MM"). + """ + assert num_layers > 0 + assert 0.0 <= attention_ratio <= 1.0 + assert 0.0 <= mlp_ratio <= 1.0 + assert attention_ratio + mlp_ratio <= 1.0 + + # Allocate attention layers (evenly spaced, starting and ending with mamba) + attention_count = round(num_layers * attention_ratio) + mamba_count = num_layers - attention_count + sections = attention_count + 1 + section_len = mamba_count / sections + + layer_types = [Symbols.MAMBA] * num_layers + x = section_len + for i in range(num_layers): + if x < 0.5: + layer_types[i] = Symbols.ATTENTION + x += section_len + else: + x -= 1 + + # Allocate MLP layers (evenly distributed, not replacing attention) + mlp_count = round(num_layers * mlp_ratio) + if mlp_count > 0: + mamba_count -= mlp_count + ratio = mamba_count / mlp_count + x = ratio + for i in range(num_layers): + if layer_types[i] == Symbols.MAMBA: + if x < 0.5: + layer_types[i] = Symbols.MLP + x += ratio + else: + x -= 1 + + return ''.join(layer_types) + + +def get_hybrid_total_layer_count(pattern: str) -> int: + """Returns the total number of main decoder layers in a hybrid layer pattern. + + Extracts the main pattern (before the first MTP separator '/'), strips + pipeline stage separators '|', and returns the character count. + + Args: + pattern: Full hybrid layer pattern, possibly including MTP and pipe separators. + + Returns: + Total number of layers in the main decoder pattern. + """ + main_pattern = pattern.split(Symbols.MTP_SEPARATOR)[0] + _validate_pattern(main_pattern, "main", allow_pipe=True) + return len(main_pattern.replace(Symbols.PIPE, '')) + + +def get_hybrid_total_pipeline_segment_count(pattern: str) -> int: + """Returns the number of pipeline segments in a hybrid layer pattern. + + Extracts the main pattern (before the first MTP separator '/') and counts + the number of segments delimited by '|'. + + Args: + pattern: Full hybrid layer pattern, possibly including MTP and pipe separators. + + Returns: + Number of pipeline segments (pipe count + 1). + """ + main_pattern = pattern.split(Symbols.MTP_SEPARATOR)[0] + return main_pattern.count(Symbols.PIPE) + 1 + + +def get_hybrid_layer_counts(pattern: str) -> Dict[str, int]: + """Count layers by type across the full hybrid pattern (main + MTP). + + Parses the pattern to extract main and MTP components, then counts + each layer type. Main pattern '|' separators are skipped. MTP layers + are counted once per MTP depth. + + Args: + pattern: Full hybrid layer pattern string. + + Returns: + Dictionary mapping layer symbol to count. Keys are Symbols.ATTENTION, + Symbols.MAMBA, Symbols.MLP, and Symbols.MOE. + + Examples: + >>> get_hybrid_layer_counts("M*M*") + {'*': 2, 'M': 2, '-': 0, 'E': 0} + + >>> get_hybrid_layer_counts("M-M-|M-M*-/MM/MM") + {'*': 1, 'M': 8, '-': 4, 'E': 0} + """ + parsed = parse_hybrid_pattern(pattern) + counts = {Symbols.ATTENTION: 0, Symbols.MAMBA: 0, Symbols.MLP: 0, Symbols.MOE: 0} + + # Count main decoder layers (skip '|' pipe separators) + if parsed.main_pattern: + for char in parsed.main_pattern: + if char in counts: + counts[char] += 1 + + # Count MTP layers (pattern repeated mtp_num_depths times) + if parsed.mtp_pattern and parsed.mtp_num_depths > 0: + for char in parsed.mtp_pattern: + if char in counts: + counts[char] += parsed.mtp_num_depths + + return counts + + def parse_hybrid_pattern(pattern: Optional[str]) -> ParsedHybridPattern: """Parse a unified hybrid pattern string into main and MTP components. The pattern uses "/" as a separator between the main decoder pattern and MTP patterns. Each MTP pattern after the separator represents one prediction - depth. + depth. The main pattern may contain "|" pipe symbols for pipeline stage + boundaries. Format: "///..." @@ -90,6 +211,9 @@ def parse_hybrid_pattern(pattern: Optional[str]) -> ParsedHybridPattern: >>> parse_hybrid_pattern("MMMM/*M/*M/*M") ParsedHybridPattern(main_pattern="MMMM", mtp_pattern="*M", mtp_num_depths=3) + + >>> parse_hybrid_pattern("M-M-|M-M*-/MM/MM") + ParsedHybridPattern(main_pattern="M-M-|M-M*-", mtp_pattern="MM", mtp_num_depths=2) """ if pattern is None: return ParsedHybridPattern(main_pattern=None, mtp_pattern=None, mtp_num_depths=0) @@ -99,13 +223,13 @@ def parse_hybrid_pattern(pattern: Optional[str]) -> ParsedHybridPattern: if len(parts) == 1: # No MTP separator found - pattern is main decoder only main_pattern = parts[0] - _validate_pattern(main_pattern, "main") + _validate_pattern(main_pattern, "main", allow_pipe=True) return ParsedHybridPattern(main_pattern=main_pattern, mtp_pattern=None, mtp_num_depths=0) # First part is main decoder pattern main_pattern = parts[0] if main_pattern: - _validate_pattern(main_pattern, "main") + _validate_pattern(main_pattern, "main", allow_pipe=True) # Remaining parts are MTP patterns (one per depth) mtp_parts = parts[1:] @@ -126,7 +250,7 @@ def parse_hybrid_pattern(pattern: Optional[str]) -> ParsedHybridPattern: f"Full pattern: '{pattern}'" ) - _validate_pattern(mtp_pattern, "MTP") + _validate_pattern(mtp_pattern, "MTP", allow_pipe=False) return ParsedHybridPattern( main_pattern=main_pattern if main_pattern else None, @@ -135,162 +259,97 @@ def parse_hybrid_pattern(pattern: Optional[str]) -> ParsedHybridPattern: ) -def _validate_pattern(pattern: str, pattern_name: str) -> None: +def _validate_pattern(pattern: str, pattern_name: str, allow_pipe: bool = False) -> None: """Validate that a pattern contains only valid layer symbols. Args: pattern: Layer pattern string to validate pattern_name: Name of pattern for error messages (e.g., "main" or "MTP") + allow_pipe: Whether to allow the pipe '|' separator (for main patterns) Raises: ValueError: If pattern contains invalid symbols """ + valid_chars = Symbols.VALID_LAYERS | {Symbols.PIPE} if allow_pipe else Symbols.VALID_LAYERS for char in pattern: - if char not in Symbols.VALID: + if char not in valid_chars: raise ValueError( f"In {pattern_name} pattern, '{char}' is not a valid layer symbol. " - f"Valid symbols are: {Symbols.VALID}" + f"Valid symbols are: {valid_chars}" ) -def _allocate_auto( - total_layers_count: int, target_attention_ratio: float, target_mlp_ratio: float -) -> list: - # First, allocate attention (evenly spaced, starting and ending with mamba) - attention_layers_count: int = round(total_layers_count * target_attention_ratio) - mamba_layers_count: int = total_layers_count - attention_layers_count - mamba_sections_count: int = attention_layers_count + 1 - mamba_section_length: float = mamba_layers_count / mamba_sections_count +def validate_segment_layers(segment: str) -> List[str]: + """Validate and convert a single pipeline segment pattern to a layer type list. - layer_type_list = [Symbols.MAMBA] * total_layers_count - x: float = mamba_section_length - for l in range(total_layers_count): - if x < 0.5: - layer_type_list[l] = Symbols.ATTENTION - x += mamba_section_length - else: - x -= 1 + This is used after the main pattern has been split by '|' into segments. + Each segment should contain only valid layer symbols (no '|'). - # Next, allocate mlp - # (evenly distributed, but right-justified, not replacing attention) - mlp_layers_count: int = round(total_layers_count * target_mlp_ratio) - if mlp_layers_count > 0: - mamba_layers_count -= mlp_layers_count - mamba_to_mlp_ratio: float = mamba_layers_count / mlp_layers_count + Args: + segment: A single pipeline segment pattern string (e.g., "M-M*-") - x: float = mamba_to_mlp_ratio - for l in range(total_layers_count): - if layer_type_list[l] == Symbols.MAMBA: - if x < 0.5: - layer_type_list[l] = Symbols.MLP - x += mamba_to_mlp_ratio - else: - x -= 1 + Returns: + List of layer type characters. + Raises: + ValueError: If segment contains invalid layer symbols. + """ + layer_type_list = list(segment) + for layer_char in layer_type_list: + if layer_char not in Symbols.VALID_LAYERS: + raise ValueError( + f"In hybrid layer pattern segment, '{layer_char}' is not " + f"one of {Symbols.VALID_LAYERS}" + ) return layer_type_list -def _allocate_override(total_layers_count: int, override_pattern: str) -> list: - layer_type_list = list(override_pattern) - override_pattern_length = len(layer_type_list) - if override_pattern_length != total_layers_count: - raise ValueError( - "The hybrid override pattern is the wrong " - f"length: got {override_pattern_length}, expected " - f"{total_layers_count}" - ) - for l in layer_type_list: - if l not in Symbols.VALID: - raise ValueError(f"In hybrid override pattern, '{l}' is not one of {Symbols.VALID}") +def select_pipeline_segment( + main_pattern: str, pp_group: Optional[torch.distributed.ProcessGroup], vp_stage: Optional[int] +) -> Tuple[List[str], int]: + """Select and validate the pipeline segment for the given PP rank and VP stage. - return layer_type_list + Splits the main pattern by '|' into pipeline segments, determines which + segment belongs to this rank based on PP rank and VP stage, validates the + segment's layer symbols, and logs the assignment. + Args: + main_pattern: Main decoder pattern (may contain '|' separators). + Empty string is allowed (produces one empty segment). + pp_group: Pipeline parallel process group, or None if not using PP. + vp_stage: Virtual pipeline stage, or None if not using VPP. -def _layer_counts_match(a: list, b: list) -> bool: - for s in Symbols.VALID: - if a.count(s) != b.count(s): - return False - return True - - -def allocate_layers( - total_layers_count: int, - target_attention_ratio: float, - target_mlp_ratio: float, - override_pattern: str = None, - silent: bool = False, -) -> list: - """Allocates layers according to the requested distribution of layer types.""" - assert total_layers_count > 0 - assert target_attention_ratio >= 0.0 and target_attention_ratio <= 1.0 - assert target_mlp_ratio >= 0.0 and target_mlp_ratio <= 1.0 - assert target_attention_ratio + target_mlp_ratio <= 1.0 - maybe_log_single_rank = (lambda *args, **kwargs: None) if silent else log_single_rank - # Note: target_mamba_ratio = 1.0 - target_attention_ratio - target_mlp_ratio - - layer_type_list = _allocate_auto(total_layers_count, target_attention_ratio, target_mlp_ratio) - - if override_pattern is not None: - layer_type_list_override = _allocate_override(total_layers_count, override_pattern) - maybe_log_single_rank(logger, logging.INFO, "Using hybrid override pattern") - if (target_attention_ratio > 0.0 or target_mlp_ratio > 0.0) and not _layer_counts_match( - layer_type_list_override, layer_type_list - ): - raise ValueError( - "The number of each type of layer in the override " - "pattern must match the number in the overridden " - "pattern." - ) - if layer_type_list_override == layer_type_list: - maybe_log_single_rank( - logger, logging.INFO, "The override pattern matches the overridden pattern" - ) - else: - maybe_log_single_rank( - logger, logging.INFO, "Warning: overriding pattern A with pattern B" - ) - maybe_log_single_rank(logger, logging.INFO, f"A: {''.join(layer_type_list)}") - maybe_log_single_rank(logger, logging.INFO, f"B: {''.join(layer_type_list_override)}") - layer_type_list = layer_type_list_override - - if target_attention_ratio > 0.0 or target_mlp_ratio > 0.0 or override_pattern is not None: - actual_attention_layers_count = layer_type_list.count(Symbols.ATTENTION) - actual_attention_ratio = actual_attention_layers_count / total_layers_count - actual_mlp_layers_count = layer_type_list.count(Symbols.MLP) - actual_mlp_ratio = actual_mlp_layers_count / total_layers_count - allocation_string = "".join(layer_type_list) - maybe_log_single_rank( - logger, - logging.INFO, - f"Hybrid allocation ({Symbols.MAMBA} is mamba, " - f"{Symbols.ATTENTION} is attention, " - f"{Symbols.MLP} is mlp):", - ) - maybe_log_single_rank(logger, logging.INFO, allocation_string) - maybe_log_single_rank( - logger, - logging.INFO, - f"{actual_attention_layers_count} attention layers in " - f"{total_layers_count} total layers.", - ) - maybe_log_single_rank( - logger, - logging.INFO, - f"Target attention ratio: {target_attention_ratio:.2f}. " - f"Actual attention ratio: {actual_attention_ratio:.2f}.", - ) - maybe_log_single_rank( - logger, - logging.INFO, - f"{actual_mlp_layers_count} mlp layers in " f"{total_layers_count} total layers.", - ) - maybe_log_single_rank( - logger, - logging.INFO, - f"Target mlp ratio: {target_mlp_ratio:.2f}. " - f"Actual mlp ratio: {actual_mlp_ratio:.2f}.", - ) - return layer_type_list + Returns: + Tuple of (layer_type_list, layer_offset) where layer_type_list is + the list of layer type characters for this segment, and layer_offset + is the sum of layer counts from all preceding segments. + + Raises: + ValueError: If the segment contains invalid layer symbols. + IndexError: If the computed segment index is out of range. + """ + segments = main_pattern.split(Symbols.PIPE) if main_pattern else [''] + + pp_rank = torch.distributed.get_rank(pp_group) if pp_group is not None else 0 + pp_size = torch.distributed.get_world_size(pp_group) if pp_group is not None else 1 + vp_rel = vp_stage if vp_stage is not None else 0 + segment_index = vp_rel * pp_size + pp_rank + + layer_offset = sum(len(segments[i]) for i in range(segment_index)) + my_segment = segments[segment_index] + + layer_type_list = validate_segment_layers(my_segment) + + log_on_each_pipeline_stage( + logger, + logging.INFO, + f"MambaModel: pp_rank={pp_rank}/{pp_size}, vp_stage={vp_rel}, " + f"segment_index={segment_index}/{len(segments)}, " + f"layers='{my_segment}' ({len(layer_type_list)} layers), " + f"layer_offset={layer_offset}", + ) + + return layer_type_list, layer_offset def get_layer_maps_from_layer_type_list( @@ -307,38 +366,3 @@ def get_layer_maps_from_layer_type_list( local_layer_idx = len(layer_map) layer_map[global_layer_idx] = local_layer_idx return [layer_maps[layer_type] for layer_type in layer_types] - - -if __name__ == "__main__": - test_cases = [ - # (10, 0.2, 0.0), - # (48, 0.0, 0.0), # will not print anything - # (48, 0.1, 0.0), - # 48, 0.3, 0.0), - # (48, 0.5, 0.0), - # (48, 0.6, 0.0), - # (48, 0.7, 0.0), - # (10, 0.0, 0.1), - # (10, 0.0, 0.3), - # (10, 0.0, 0.5), - # (10, 0.1, 0.1), - # (10, 0.2, 0.2), - # (10, 0.3, 0.3), - # (10, 0.5, 0.5), - # (48, 0.2, 0.3), - # (48, 0.5, 0.2), - # (48, 0.5, 0.2, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"), - # (48, 0.25, 0.25, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"), - # (48, 0.25, 0.25, "MM-*MM-*MM*-MM*-MM*-MM*-M*M-M*M-M*M-M*M-*MM-*MM-"), - # (48, 0.0, 0.2, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"), - # (48, 0.2, 0.0, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"), - # (48, 0.0, 0.0, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"), - # (48, 0.5, 0.5), - # (10, 0.3, 0.2, "MMM*-*M*M-"), - # (10, 0.3, 0.2, "MM*M-*M*M-"), - (9, 0.0, 0.0, "M*-M*-M*-"), - (9, 0.0, 0.0, "MMMMMMMMM"), - ] - for t in test_cases: - logging.info("") - allocate_layers(*t) diff --git a/megatron/core/ssm/mlp_layer.py b/megatron/core/ssm/mlp_layer.py index 19aec5878b2..e1668a01381 100644 --- a/megatron/core/ssm/mlp_layer.py +++ b/megatron/core/ssm/mlp_layer.py @@ -20,6 +20,7 @@ def __init__( layer_number: int = 1, hidden_dropout: float = None, pg_collection: Optional[ProcessGroupCollection] = None, + add_layer_offset: bool = True, ): super().__init__( config=config, @@ -27,4 +28,5 @@ def __init__( layer_number=layer_number, hidden_dropout=hidden_dropout, pg_collection=pg_collection, + add_layer_offset=add_layer_offset, ) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 019c6fef396..28e3dde01c4 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -252,11 +252,13 @@ def __init__( attention_type: str, cp_comm_type: str | None = None, pg_collection: ProcessGroupCollection | None = None, + pp_layer_offset: Optional[int] = None, ): super().__init__(config=config) self.config = config self.layer_number = layer_number + self._pp_layer_offset = pp_layer_offset self.attn_mask_type = attn_mask_type self.attention_type = attention_type @@ -432,7 +434,16 @@ def _allocate_memory(self, inference_max_sequence_length, batch_size, dim, dtype ) def _get_pp_layer_offset_for_inference(self): - """Return the pipeline parallel layer offset for inference.""" + """Return the pipeline parallel layer offset for inference. + + When pp_layer_offset was explicitly provided (e.g. by MambaBlock for + hybrid models using --hybrid-layer-pattern with fVPP), use that value + directly. Otherwise fall back to the standard computation which assumes + uniform layer distribution across pipeline stages. + """ + if self._pp_layer_offset is not None: + return self._pp_layer_offset + assert ( self.config.virtual_pipeline_model_parallel_size is None ), "Virtual pipeline parallelism is not supported for inference" @@ -1251,6 +1262,7 @@ def __init__( attn_mask_type: AttnMaskType = AttnMaskType.padding, cp_comm_type: str | None = None, pg_collection: ProcessGroupCollection | None = None, + pp_layer_offset: Optional[int] = None, ): super().__init__( config=config, @@ -1260,6 +1272,7 @@ def __init__( attention_type="self", cp_comm_type=cp_comm_type, pg_collection=pg_collection, + pp_layer_offset=pp_layer_offset, ) self.linear_qkv_out_dim = self.query_projection_size + 2 * self.kv_projection_size diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index 9d588c04860..bcef07de73f 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -801,11 +801,13 @@ def __init__( # 2. GPT path: single TransformerLayer if mtp_layer_pattern is not None and mamba_submodules is not None: from megatron.core.ssm.mamba_block import MambaStack + from megatron.core.ssm.mamba_hybrid_layer_allocation import validate_segment_layers self.mtp_model_layer = MambaStack( config=self.config, submodules=mamba_submodules, - hybrid_override_pattern=mtp_layer_pattern, + layer_type_list=validate_segment_layers(mtp_layer_pattern), + pp_layer_offset=0, pre_process=True, # Always receives input from eh_proj post_layer_norm=False, # MTP has its own final_layernorm post_process=True, # MTP layer is self-contained diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 1aad3c4b89f..2b0a66ae268 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -63,7 +63,7 @@ class TransformerConfig(ModelParallelConfig): """Use a single MTP layer repeatedly instead of multiple separate layers.""" mtp_hybrid_override_pattern: Optional[str] = None - """DEPRECATED: Use unified hybrid_override_pattern instead. + """DEPRECATED: Use unified hybrid_layer_pattern instead. Legacy argument for loading old checkpoints. Force a specific hybrid layer pattern for MTP layers. """ diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 9f7fa1397d2..68bcdd1e8ad 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -275,6 +275,8 @@ def __init__( pg_collection: Optional[ProcessGroupCollection] = None, vp_stage: Optional[int] = None, is_mtp_layer: bool = False, + add_layer_offset: bool = True, + pp_layer_offset: Optional[int] = None, ): self.submodules_config = submodules super().__init__(config=config, vp_stage=vp_stage) @@ -288,7 +290,10 @@ def __init__( # so they should NOT add the decoder layer offset. The router.py handles MTP layer # numbering separately by adding config.num_layers to distinguish MTP layers from decoder # layers in the aux loss tracker. - if is_mtp_layer: + # + # When add_layer_offset is False, the caller has already included the correct offset + # in layer_number (e.g. when using --hybrid-layer-pattern with fVPP). + if is_mtp_layer or not add_layer_offset: self.layer_number = layer_number else: self.layer_number = layer_number + get_transformer_layer_offset( @@ -313,6 +318,8 @@ def __init__( attention_optional_kwargs["cp_comm_type"] = config.cp_comm_type attention_optional_kwargs["pg_collection"] = pg_collection + if pp_layer_offset is not None: + attention_optional_kwargs["pp_layer_offset"] = pp_layer_offset # [Module 2: SelfAttention] self.self_attention = build_module( diff --git a/megatron/post_training/model_builder.py b/megatron/post_training/model_builder.py index 7d38041a7a3..085d188e811 100644 --- a/megatron/post_training/model_builder.py +++ b/megatron/post_training/model_builder.py @@ -106,6 +106,11 @@ def _load_teacher_model_config(checkpoint_path: str) -> Namespace: del args_dict["kv_channels"] # not recalculated if present args_dict.update(config) + # Backward compat: old checkpoints have hybrid_override_pattern but not hybrid_layer_pattern + if (args_dict.get('hybrid_override_pattern') is not None + and args_dict.get('hybrid_layer_pattern') is None): + args_dict['hybrid_layer_pattern'] = args_dict['hybrid_override_pattern'] + return Namespace(**args_dict) @@ -114,13 +119,10 @@ def _load_teacher_model(config, config_raw: Namespace, model_kwargs: Dict[str, A args = get_args() if config.is_hybrid_model: - # These parameters are not part of the TransformerConfig and need to be passed separately. - if "hybrid_override_pattern" in config_raw: - model_kwargs["hybrid_override_pattern"] = config_raw.hybrid_override_pattern - if "hybrid_attention_ratio" in config_raw: - model_kwargs["hybrid_attention_ratio"] = config_raw.hybrid_attention_ratio - if "hybrid_mlp_ratio" in config_raw: - model_kwargs["hybrid_mlp_ratio"] = config_raw.hybrid_mlp_ratio + # This parameter is not part of the TransformerConfig and needs to be passed separately. + # Note: hybrid_override_pattern is remapped to hybrid_layer_pattern in + # _load_teacher_model_config, so config_raw.hybrid_layer_pattern is always set here. + model_kwargs["hybrid_layer_pattern"] = config_raw.hybrid_layer_pattern teacher = MCoreMambaModel(config=config, **model_kwargs) else: @@ -257,7 +259,7 @@ def modelopt_gpt_mamba_builder( "pg_collection": pg_collection, } model = MCoreGPTModel(config=config, **model_kwargs) - elif args.export_model_type == "MambaModel" or args.is_hybrid_model: + elif args.export_model_type == "MambaModel" or getattr(args, 'hybrid_layer_pattern', None) is not None: from megatron.core.post_training.modelopt.mamba.model_specs import get_mamba_stack_modelopt_spec if args.export_default_te_spec and args.export_te_mcore_model: @@ -275,10 +277,8 @@ def modelopt_gpt_mamba_builder( "mamba_stack_spec": mamba_stack_spec, "vocab_size": args.padded_vocab_size, "max_sequence_length": args.max_position_embeddings, + "hybrid_layer_pattern": args.hybrid_layer_pattern, "pre_process": pre_process, - "hybrid_attention_ratio": args.hybrid_attention_ratio, - "hybrid_mlp_ratio": args.hybrid_mlp_ratio, - "hybrid_override_pattern": args.hybrid_override_pattern, "post_process": post_process, "fp16_lm_cross_entropy": args.fp16_lm_cross_entropy, "parallel_output": True, diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index bd6143409e7..dad4db9189f 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -553,35 +553,123 @@ def validate_args(args, defaults={}): print_rank_0('setting global batch size to {}'.format(args.global_batch_size)) assert args.global_batch_size > 0 - # === MTP validation === - # Deprecation warnings for legacy MTP arguments + # === Hybrid layer pattern: deprecation handling and validation === + + # Backward compat: --hybrid-override-pattern is deprecated in favor of --hybrid-layer-pattern + used_hybrid_override_pattern = False + if args.hybrid_override_pattern is not None: + assert args.hybrid_layer_pattern is None, ( + '--hybrid-override-pattern and --hybrid-layer-pattern cannot both be specified. ' + '--hybrid-override-pattern is deprecated; use --hybrid-layer-pattern instead.' + ) + warn_rank_0( + "--hybrid-override-pattern is deprecated. Use --hybrid-layer-pattern instead.", + args.rank, + ) + args.hybrid_layer_pattern = args.hybrid_override_pattern + used_hybrid_override_pattern = True + if args.mtp_hybrid_override_pattern is not None: warn_rank_0( "--mtp-hybrid-override-pattern is deprecated. " - "For new hybrid models with MTP models, use unified --hybrid-override-pattern instead. " + "For new hybrid models with MTP, use unified --hybrid-layer-pattern instead. " "Example: 'M*M*/MM/MM' means main='M*M*', MTP pattern='MM' with 2 depths. " "This argument is kept only for loading old checkpoints.", args.rank, ) - # Backward compatibility: convert legacy mtp_hybrid_override_pattern to unified format - from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols, parse_hybrid_pattern + from megatron.core.ssm.mamba_hybrid_layer_allocation import ( + Symbols, parse_hybrid_pattern, get_hybrid_total_layer_count, + get_hybrid_total_pipeline_segment_count, + ) sep = Symbols.MTP_SEPARATOR + + # Backward compat: convert legacy mtp_hybrid_override_pattern to unified format if ( - getattr(args, 'mtp_hybrid_override_pattern', None) is not None + args.mtp_hybrid_override_pattern is not None and args.mtp_num_layers is not None and args.mtp_num_layers > 0 - and (args.hybrid_override_pattern is None or sep not in args.hybrid_override_pattern) + and (args.hybrid_layer_pattern is None or sep not in args.hybrid_layer_pattern) ): - main_pattern = args.hybrid_override_pattern or '' + main_pattern = args.hybrid_layer_pattern or '' mtp_pattern = args.mtp_hybrid_override_pattern - args.hybrid_override_pattern = main_pattern + sep + sep.join([mtp_pattern] * args.mtp_num_layers) + args.hybrid_layer_pattern = main_pattern + sep + sep.join([mtp_pattern] * args.mtp_num_layers) args.mtp_hybrid_override_pattern = None - print_rank_0(f"Converted legacy MTP pattern to unified: {args.hybrid_override_pattern}") + print_rank_0(f"Converted legacy MTP pattern to unified: {args.hybrid_layer_pattern}") + + if args.hybrid_layer_pattern is not None: + # Derive num_layers from pattern + num_layers_in_pattern = get_hybrid_total_layer_count(args.hybrid_layer_pattern) + if args.num_layers is not None: + if used_hybrid_override_pattern: + assert args.num_layers == num_layers_in_pattern, ( + f'--num-layers ({args.num_layers}) does not match the number of layers ' + f'derived from --hybrid-override-pattern ({num_layers_in_pattern}). ' + f'Please correct --num-layers or the pattern.' + ) + else: + assert False, ( + 'If --hybrid-layer-pattern is specified, --num-layers should not be specified. ' + 'The number of layers is derived from the pattern.' + ) + args.num_layers = num_layers_in_pattern + + # These arguments are incompatible with --hybrid-layer-pattern + assert args.decoder_first_pipeline_num_layers is None, ( + 'If --hybrid-layer-pattern is specified, --decoder-first-pipeline-num-layers ' + 'should not be specified' + ) + assert args.decoder_last_pipeline_num_layers is None, ( + 'If --hybrid-layer-pattern is specified, --decoder-last-pipeline-num-layers ' + 'should not be specified' + ) + assert args.num_layers_per_virtual_pipeline_stage is None, ( + '--num-layers-per-virtual-pipeline-stage should not be used with ' + '--hybrid-layer-pattern. To specify virtual pipelining, describe a number of ' + 'pipeline segments in --hybrid-layer-pattern that is a multiple of ' + '--pipeline-model-parallel-size greater than 1' + ) + assert args.num_virtual_stages_per_pipeline_rank is None, ( + '--num-virtual-stages-per-pipeline-rank should not be used with ' + '--hybrid-layer-pattern. Virtual pipeline stages are derived from the ' + 'number of | segments in the pattern.' + ) + assert args.pipeline_model_parallel_layout is None, ( + '--pipeline-model-parallel-layout should not be used with --hybrid-layer-pattern. ' + 'Pipeline stage layout is defined by | separators in the pattern.' + ) + assert not args.account_for_embedding_in_pipeline_split, ( + '--account-for-embedding-in-pipeline-split should not be used with ' + '--hybrid-layer-pattern. Pipeline stage layout is defined by | separators ' + 'in the pattern.' + ) + assert not args.account_for_loss_in_pipeline_split, ( + '--account-for-loss-in-pipeline-split should not be used with ' + '--hybrid-layer-pattern. Pipeline stage layout is defined by | separators ' + 'in the pattern.' + ) + + # Derive VPP from pipe segments in the pattern + hybrid_pipeline_segments = get_hybrid_total_pipeline_segment_count( + args.hybrid_layer_pattern + ) + assert hybrid_pipeline_segments % args.transformer_pipeline_model_parallel_size == 0, ( + 'The number of hybrid pipeline segments described by --hybrid-layer-pattern must ' + 'be evenly divisible by --pipeline-model-parallel-size. ' + f'Got {hybrid_pipeline_segments} segments and ' + f'{args.transformer_pipeline_model_parallel_size} pipeline parallel size.' + ) + if hybrid_pipeline_segments > args.transformer_pipeline_model_parallel_size: + # Must be set here in order to assign virtual parallel ranks in training.py/get_model + args.virtual_pipeline_model_parallel_size = ( + hybrid_pipeline_segments // args.transformer_pipeline_model_parallel_size + ) + else: + args.virtual_pipeline_model_parallel_size = None # Infer mtp_num_layers from unified pattern - if args.hybrid_override_pattern and sep in args.hybrid_override_pattern: - parsed = parse_hybrid_pattern(args.hybrid_override_pattern) + if args.hybrid_layer_pattern and sep in args.hybrid_layer_pattern: + parsed = parse_hybrid_pattern(args.hybrid_layer_pattern) if parsed.mtp_pattern and parsed.mtp_num_depths > 0: inferred_mtp_num_layers = parsed.mtp_num_depths if args.mtp_num_layers is None: @@ -589,7 +677,8 @@ def validate_args(args, defaults={}): elif args.mtp_num_layers != inferred_mtp_num_layers: warn_rank_0( f"--mtp-num-layers ({args.mtp_num_layers}) conflicts with " - f"MTP depth count ({inferred_mtp_num_layers}) in pattern '{args.hybrid_override_pattern}'. " + f"MTP depth count ({inferred_mtp_num_layers}) in pattern " + f"'{args.hybrid_layer_pattern}'. " f"Using the inferred value ({inferred_mtp_num_layers}).", args.rank ) @@ -604,14 +693,14 @@ def validate_args(args, defaults={}): ) # Validate MTP args for hybrid vs non-hybrid models - if args.is_hybrid_model: + if args.hybrid_layer_pattern is not None: # Mamba/hybrid model MTP validation - if args.mtp_num_layers and not (args.hybrid_override_pattern and sep in args.hybrid_override_pattern): + if args.mtp_num_layers and not (args.hybrid_layer_pattern and sep in args.hybrid_layer_pattern): # Hybrid model wants MTP but no unified pattern - check for legacy args if args.mtp_hybrid_override_pattern is None: warn_rank_0( "Hybrid model with --mtp-num-layers but no MTP pattern. " - "Use unified --hybrid-override-pattern with '/' separator (e.g., 'M*M*/MM/MM') " + "Use unified --hybrid-layer-pattern with '/' separator (e.g., 'M*M*/MM/MM') " "or legacy --mtp-hybrid-override-pattern for old checkpoints.", args.rank ) @@ -624,8 +713,8 @@ def validate_args(args, defaults={}): "This argument will be ignored.", args.rank ) - # === End of MTP validation === - + # === End of hybrid layer pattern: deprecation handling and validation === + # Uneven virtual pipeline parallelism assert ( int(args.num_layers_per_virtual_pipeline_stage is not None) @@ -680,12 +769,14 @@ def validate_args(args, defaults={}): if args.virtual_pipeline_model_parallel_size == 1: args.virtual_pipeline_model_parallel_size = None else: - args.virtual_pipeline_model_parallel_size = None + # Only set VPP to None if it wasn't already derived from --hybrid-layer-pattern + if args.hybrid_layer_pattern is None: + args.virtual_pipeline_model_parallel_size = None if args.decoder_first_pipeline_num_layers is None and args.decoder_last_pipeline_num_layers is None: # Divisibility check not applicable for T5 models which specify encoder_num_layers - # and decoder_num_layers. - if args.num_layers is not None: + # and decoder_num_layers, or for hybrid models using --hybrid-layer-pattern. + if args.num_layers is not None and args.hybrid_layer_pattern is None: num_layers = args.num_layers if args.account_for_embedding_in_pipeline_split: @@ -1521,8 +1612,8 @@ def core_transformer_config_from_args(args, config_class=None): if len(args.cp_comm_type) == 1: kw_args['cp_comm_type'] = args.cp_comm_type[0] - if args.is_hybrid_model: - kw_args['is_hybrid_model'] = args.is_hybrid_model + if args.hybrid_layer_pattern is not None: + kw_args['is_hybrid_model'] = True kw_args['inference_sampling_seed'] = args.seed @@ -2937,20 +3028,16 @@ def _add_experimental_args(parser): 'To use local spec specify local as the argument.' 'For more details, see the model class, ' '`transformer_block.py`, or `transformer_layer.py`') - group.add_argument('--hybrid-attention-ratio', type=float, default=0.0, - help='Ratio of attention layers to total layers, in the ' - 'range [0.0, 1.0].') - group.add_argument('--hybrid-mlp-ratio', type=float, default=0.0, - help='Ratio of mlp layers to total layers, in the ' - 'range [0.0, 1.0].') + group.add_argument('--hybrid-layer-pattern', type=str, default=None, + help='Specify a hybrid layer pattern using M (mamba), * (attention), ' + '- (mlp), E (moe). Use | to define pipeline stage boundaries for ' + 'flexible virtual pipeline parallel (fVPP). Use / to separate MTP ' + 'patterns. Example: "M-M-|M-M*-|M-M-|M-M*-" or "M-M-|M-M*-/MM/MM". ' + 'When this flag is used, it is the sole indicator that a hybrid model ' + 'is being run.') group.add_argument('--hybrid-override-pattern', type=str, default=None, - help='Force a specific hybrid layer pattern. The value' - 'should be a string of characters chosen from' - 'core.ssm.mamba_hybrid_layer_allocation.Symbols.' - 'If a value greater than 0.0 is supplied to any of the ' - 'hybrid ratio arguments, then the number of each type' - 'of layer in the override pattern must match number in' - 'the overidden pattern') + help='Deprecated. Use --hybrid-layer-pattern instead. ' + 'If specified, its value will be forwarded to --hybrid-layer-pattern.') group.add_argument('--yaml-cfg', type=str, default=None, help = 'Config file to add additional arguments') diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index 8ced7d267d6..b192318f59a 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -1375,6 +1375,18 @@ def load_args_from_checkpoint( checkpoint_args, 'add_bias_linear', not getattr(checkpoint_args, 'disable_bias_linear') ) + # Backward compat: old checkpoints have hybrid_override_pattern but not hybrid_layer_pattern + if (getattr(checkpoint_args, 'hybrid_override_pattern', None) is not None + and getattr(checkpoint_args, 'hybrid_layer_pattern', None) is None): + setattr( + checkpoint_args, 'hybrid_layer_pattern', + getattr(checkpoint_args, 'hybrid_override_pattern'), + ) + # num_layers is now derived from hybrid_layer_pattern in validate_args, and should not be + # set at the same time as hybrid_layer_pattern. + if hasattr(checkpoint_args, 'num_layers'): + setattr(checkpoint_args, 'num_layers', None) + def _set_arg(arg_name, old_arg_name=None, force=False): if not force and getattr(args, arg_name, None) is not None: return @@ -1417,16 +1429,12 @@ def _set_arg(arg_name, old_arg_name=None, force=False): _set_arg('attention_dropout', force=True) _set_arg('hidden_dropout', force=True) - _set_arg('hybrid_override_pattern', force=True) - # Legacy MTP pattern for old checkpoints _set_arg('mtp_hybrid_override_pattern', force=True) _set_arg('mtp_num_layers', force=True) _set_arg('mtp_use_repeated_layer', force=True) _set_arg('spec', force=True) - _set_arg('hybrid_attention_ratio', force=True) - _set_arg('hybrid_mlp_ratio', force=True) _set_arg('num_experts', force=True) _set_arg('moe_layer_freq', force=True) @@ -1448,7 +1456,9 @@ def _set_arg(arg_name, old_arg_name=None, force=False): _set_arg('mamba_head_dim', force=True) _set_arg('mamba_num_groups', force=True) _set_arg('mamba_num_heads', force=True) - _set_arg('is_hybrid_model', force=True) + # We need to be able to override hybrid_layer_pattern from the command-line so that different + # pipelining can be specified when re-loading a model (e.g. for inference or post-training). + _set_arg('hybrid_layer_pattern') # Heterogeneous args. _set_arg('heterogeneous_layers_config_path', force=True) diff --git a/megatron/training/theoretical_memory_usage.py b/megatron/training/theoretical_memory_usage.py index 7d4043b62d7..d3b40e8d533 100644 --- a/megatron/training/theoretical_memory_usage.py +++ b/megatron/training/theoretical_memory_usage.py @@ -4,7 +4,7 @@ import math -from .utils import print_rank_0 +from .utils import is_hybrid_model, print_rank_0 NUM_BYTES_IN_MEGABYTE = 1024 * 1024 @@ -338,7 +338,7 @@ def compute_activation_memory_without_sp(args, num_microbatches, verbose=False): def report_theoretical_memory(args, num_microbatches=None, verbose=False): - if args.is_hybrid_model: + if is_hybrid_model(args): print("Theoretical memory footprints not yet supported for hybrid Mamba-Transformer models.") return diff --git a/megatron/training/training.py b/megatron/training/training.py index 2c68c70735d..8dacc10e035 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -137,7 +137,7 @@ def set_startup_timestamps(program_start=None, main_entry=None): from megatron.training.initialize import initialize_megatron from megatron.training.initialize import write_args_to_tensorboard from megatron.training.initialize import set_jit_fusion_options -from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank +from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, is_hybrid_model from megatron.training.datasets.data_samplers import build_pretraining_data_loader from megatron.core.datasets.data_schedule import HybridCPDataLoaderWrapper from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler @@ -226,31 +226,6 @@ def print_datetime(string, override_timestamp=None): print_rank_0(f'[{string}] datetime: {time_str} ') def num_floating_point_operations(args, batch_size): - def calculate_layer_counts(): - """Calculate the number of attention, Mamba, and MLP layers.""" - if args.hybrid_override_pattern: - from megatron.core.ssm.mamba_hybrid_layer_allocation import parse_hybrid_pattern - # Parse unified pattern to separate main and MTP components - parsed = parse_hybrid_pattern(args.hybrid_override_pattern) - counts = {'M': 0, '*': 0, '-': 0, 'E': 0} - # Count main decoder layers - if parsed.main_pattern: - for layer_type in parsed.main_pattern: - if layer_type in counts: - counts[layer_type] += 1 - # Count MTP layers (pattern repeated mtp_num_depths times) - if parsed.mtp_pattern and parsed.mtp_num_depths > 0: - for layer_type in parsed.mtp_pattern: - if layer_type in counts: - counts[layer_type] += parsed.mtp_num_depths - return counts['*'], counts['M'], counts['-'], counts['E'] - else: - num_attn_layers = round(args.num_layers * args.hybrid_attention_ratio) - num_mlp_layers = round(args.num_layers * args.hybrid_mlp_ratio) - num_mamba_layers = args.num_layers - num_attn_layers - num_mlp_layers - num_moe_layers = 0 - return num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers - def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False): """Calculate FLOPs for an MLP layer.""" scale_factor = 3.0 / 2.0 if swiglu else 1.0 @@ -607,9 +582,14 @@ def transformer_flops(): return total_floating_point_operations # Main entrypoint for FLOPs calculation. - if args.is_hybrid_model: + if is_hybrid_model(args): # Calculate the number of each type of layer. - num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers = calculate_layer_counts() + from operator import itemgetter + + from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols, get_hybrid_layer_counts + num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers = itemgetter( + Symbols.ATTENTION, Symbols.MAMBA, Symbols.MLP, Symbols.MOE + )(get_hybrid_layer_counts(args.hybrid_layer_pattern)) mtp_num_layers = args.mtp_num_layers if mtp_num_layers is None: @@ -2032,8 +2012,13 @@ def training_log( if args.moe_z_loss_coeff is not None: track_names.append("z_loss") - if args.is_hybrid_model: - layers = args.hybrid_override_pattern.count('E') + if is_hybrid_model(args): + from operator import itemgetter + + from megatron.core.ssm.mamba_hybrid_layer_allocation import ( + Symbols, get_hybrid_layer_counts, + ) + layers = itemgetter(Symbols.MOE)(get_hybrid_layer_counts(args.hybrid_layer_pattern)) else: layers = args.num_layers diff --git a/megatron/training/utils.py b/megatron/training/utils.py index edd50dc831f..f0a5cac3176 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -431,6 +431,11 @@ def print_rank_last(message): print(message, flush=True) +def is_hybrid_model(args): + """Returns True if the model is a hybrid Mamba-Transformer model.""" + return args.hybrid_layer_pattern is not None + + def is_first_or_last_pipeline_stage(vp_stage): """Return True if on first or last pipeline stage, taking into account virtual pipeline parallelism.""" diff --git a/tests/functional_tests/shell_test_utils/_run_training.sh b/tests/functional_tests/shell_test_utils/_run_training.sh index 72fd187d19d..8f848a24add 100644 --- a/tests/functional_tests/shell_test_utils/_run_training.sh +++ b/tests/functional_tests/shell_test_utils/_run_training.sh @@ -124,8 +124,8 @@ else value=$(echo "$value" | sed 's/^\[//;s/\]$//') TRAINING_PARAMS_FROM_CONFIG+="$key $value " - # Case: contains spaces - elif [[ "$value" == *" "* ]]; then + # Case: contains spaces or shell metacharacters + elif [[ "$value" == *" "* || "$value" == *"|"* || "$value" == *"("* || "$value" == *")"* ]]; then TRAINING_PARAMS_FROM_CONFIG+="$key \"$value\" " # Case: default else diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m/model_config.yaml index 0232bcb30bf..b1ffaaa0147 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m/model_config.yaml @@ -22,19 +22,17 @@ MODEL_ARGS: --pipeline-model-parallel-size: 1 --expert-model-parallel-size: 1 --use-mcore-models: true - --is-hybrid-model: true --model-provider: mamba --init-method-std: 0.0198 --untie-embeddings-and-output-weights: true --disable-bias-linear: true --init-method-std: 0.014 --position-embedding-type: none - --num-layers: 50 --hidden-size: 2048 --ffn-hidden-size: 11264 --num-attention-heads: 16 --kv-channels: 128 - --hybrid-override-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- + --hybrid-layer-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- --spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec --normalization: RMSNorm --swiglu: true diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_chunked_prefill/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_chunked_prefill/model_config.yaml index 7ff5911a877..36148fb30c9 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_chunked_prefill/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_chunked_prefill/model_config.yaml @@ -22,19 +22,17 @@ MODEL_ARGS: --pipeline-model-parallel-size: 1 --expert-model-parallel-size: 1 --use-mcore-models: true - --is-hybrid-model: true --model-provider: mamba --init-method-std: 0.0198 --untie-embeddings-and-output-weights: true --disable-bias-linear: true --init-method-std: 0.014 --position-embedding-type: none - --num-layers: 50 --hidden-size: 2048 --ffn-hidden-size: 11264 --num-attention-heads: 16 --kv-channels: 128 - --hybrid-override-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- + --hybrid-layer-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- --spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec --normalization: RMSNorm --swiglu: true diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp1_cp1_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp1_cp1_dgx_a100_1N8G/model_config.yaml index 22d0cbaa3bf..ddb776d0aee 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp1_cp1_dgx_a100_1N8G/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp1_cp1_dgx_a100_1N8G/model_config.yaml @@ -4,14 +4,12 @@ ENV_VARS: NCCL_ALGO: Ring CUBLAS_WORKSPACE_CONFIG: :4096:8 MODEL_ARGS: - --num-layers: 44 --hidden-size: 1024 --num-attention-heads: 16 --group-query-attention: true --num-query-groups: 8 - --hybrid-override-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M- + --hybrid-layer-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M- --spec: "[megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec]" - --is-hybrid-model: true --log-params-norm: true --log-num-zeros-in-grad: true --log-validation-ppl-to-tensorboard: true diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp2_vpp2_cp1_dgx_a100_1N8G/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp2_vpp2_cp1_dgx_a100_1N8G/golden_values_dev_dgx_h100.json new file mode 100644 index 00000000000..b3308e04b06 --- /dev/null +++ b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp2_vpp2_cp1_dgx_a100_1N8G/golden_values_dev_dgx_h100.json @@ -0,0 +1,287 @@ +{ + "lm loss": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 10.98453, + "2": 10.98656, + "3": 10.98558, + "4": 10.97462, + "5": 11.01243, + "6": 11.01486, + "7": 10.99829, + "8": 10.98374, + "9": 10.97958, + "10": 10.95257, + "11": 11.00151, + "12": 10.97264, + "13": 10.97068, + "14": 10.9819, + "15": 10.86751, + "16": 10.86056, + "17": 10.82417, + "18": 10.83853, + "19": 10.82792, + "20": 10.63567, + "21": 10.5832, + "22": 10.34766, + "23": 10.61001, + "24": 10.3489, + "25": 10.24413, + "26": 10.37199, + "27": 10.3839, + "28": 10.34912, + "29": 10.3595, + "30": 9.90123, + "31": 9.46177, + "32": 10.08687, + "33": 10.07688, + "34": 9.63497, + "35": 9.68183, + "36": 9.56636, + "37": 9.80399, + "38": 9.50995, + "39": 9.91757, + "40": 9.32825, + "41": 9.47987, + "42": 9.55419, + "43": 9.02825, + "44": 9.14665, + "45": 8.99067, + "46": 9.05279, + "47": 9.47035, + "48": 9.03541, + "49": 8.57937, + "50": 9.11692 + } + }, + "num-zeros": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 20941.0, + "2": 21706.0, + "3": 21312.0, + "4": 21029.0, + "5": 23603.0, + "6": 23558.0, + "7": 23492.0, + "8": 21856.0, + "9": 23107.0, + "10": 19088.0, + "11": 24660.0, + "12": 23306.0, + "13": 24163.0, + "14": 24444.0, + "15": 23148.0, + "16": 23702.0, + "17": 22014.0, + "18": 22378.0, + "19": 23608.0, + "20": 21520.0, + "21": 22232.0, + "22": 18801.0, + "23": 24318.0, + "24": 19502.0, + "25": 19048.0, + "26": 20393.0, + "27": 21793.0, + "28": 22862.0, + "29": 22737.0, + "30": 19741.0, + "31": 16792.0, + "32": 21327.0, + "33": 22863.0, + "34": 21230.0, + "35": 21207.0, + "36": 20330.0, + "37": 22367.0, + "38": 22291.0, + "39": 22436.0, + "40": 23187.0, + "41": 24131.0, + "42": 23488.0, + "43": 21513.0, + "44": 21418.0, + "45": 21854.0, + "46": 22905.0, + "47": 24925.0, + "48": 24925.0, + "49": 25464.0, + "50": 27681.0 + } + }, + "mem-allocated-bytes": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 2178379776.0, + "2": 2178379776.0, + "3": 2178641920.0, + "4": 2179428352.0, + "5": 2178641920.0, + "6": 2178641920.0, + "7": 2178641920.0, + "8": 2178641920.0, + "9": 2178641920.0, + "10": 2178641920.0, + "11": 2178641920.0, + "12": 2178641920.0, + "13": 2178641920.0, + "14": 2178641920.0, + "15": 2178641920.0, + "16": 2178641920.0, + "17": 2178641920.0, + "18": 2178641920.0, + "19": 2178641920.0, + "20": 2178641920.0, + "21": 2178641920.0, + "22": 2178641920.0, + "23": 2178641920.0, + "24": 2178641920.0, + "25": 2178641920.0, + "26": 2178641920.0, + "27": 2178641920.0, + "28": 2178641920.0, + "29": 2178641920.0, + "30": 2178641920.0, + "31": 2178641920.0, + "32": 2178641920.0, + "33": 2178641920.0, + "34": 2178641920.0, + "35": 2178641920.0, + "36": 2178641920.0, + "37": 2178641920.0, + "38": 2178379776.0, + "39": 2178641920.0, + "40": 2178379776.0, + "41": 2178379776.0, + "42": 2178904064.0, + "43": 2178379776.0, + "44": 2178641920.0, + "45": 2178641920.0, + "46": 2178641920.0, + "47": 2178641920.0, + "48": 2178379776.0, + "49": 2178641920.0, + "50": 2178641920.0 + } + }, + "mem-max-allocated-bytes": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 7284575744.0, + "2": 7744184832.0, + "3": 7744184832.0, + "4": 7744709120.0, + "5": 7744709120.0, + "6": 7744709120.0, + "7": 7744709120.0, + "8": 7744709120.0, + "9": 7744709120.0, + "10": 7744709120.0, + "11": 7744709120.0, + "12": 7744709120.0, + "13": 7744709120.0, + "14": 7744709120.0, + "15": 7744709120.0, + "16": 7744709120.0, + "17": 7744709120.0, + "18": 7744709120.0, + "19": 7744709120.0, + "20": 7744709120.0, + "21": 7744709120.0, + "22": 7744709120.0, + "23": 7744709120.0, + "24": 7744709120.0, + "25": 7744709120.0, + "26": 7744709120.0, + "27": 7744709120.0, + "28": 7744709120.0, + "29": 7744709120.0, + "30": 7744709120.0, + "31": 7744709120.0, + "32": 7744709120.0, + "33": 7744709120.0, + "34": 7744709120.0, + "35": 7744709120.0, + "36": 7744709120.0, + "37": 7744709120.0, + "38": 7744709120.0, + "39": 7744709120.0, + "40": 7744709120.0, + "41": 7744709120.0, + "42": 7744709120.0, + "43": 7744709120.0, + "44": 7744709120.0, + "45": 7744709120.0, + "46": 7744709120.0, + "47": 7744709120.0, + "48": 7744709120.0, + "49": 7744709120.0, + "50": 7744709120.0 + } + }, + "iteration-time": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": "nan", + "2": 68.99515, + "3": 0.16958, + "4": 0.16632, + "5": 0.167, + "6": 0.16564, + "7": 0.16619, + "8": 0.16588, + "9": 0.16662, + "10": 0.16785, + "11": 0.82413, + "12": 0.1671, + "13": 0.16722, + "14": 0.16724, + "15": 0.16551, + "16": 0.1671, + "17": 0.16656, + "18": 0.1668, + "19": 0.16522, + "20": 0.16556, + "21": 0.81885, + "22": 0.16567, + "23": 0.16748, + "24": 0.16601, + "25": 0.16584, + "26": 0.16611, + "27": 0.16667, + "28": 0.16529, + "29": 0.1659, + "30": 0.16604, + "31": 0.80768, + "32": 0.16703, + "33": 0.16588, + "34": 0.16788, + "35": 0.16511, + "36": 0.16508, + "37": 0.1652, + "38": 0.16527, + "39": 0.16626, + "40": 0.16583, + "41": 0.81579, + "42": 0.1665, + "43": 0.16683, + "44": 0.16836, + "45": 0.16702, + "46": 0.1654, + "47": 0.16533, + "48": 0.16527, + "49": 0.16499, + "50": 0.16589 + } + } +} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp2_vpp2_cp1_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp2_vpp2_cp1_dgx_a100_1N8G/model_config.yaml new file mode 100644 index 00000000000..3f2a25be6b4 --- /dev/null +++ b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp2_vpp2_cp1_dgx_a100_1N8G/model_config.yaml @@ -0,0 +1,59 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 1 + NCCL_ALGO: Ring + CUBLAS_WORKSPACE_CONFIG: :4096:8 +MODEL_ARGS: + --hidden-size: 1024 + --num-attention-heads: 16 + --group-query-attention: true + --num-query-groups: 8 + --hybrid-layer-pattern: M-M-M-M*-M-|M-M-M*-M-M-|M-M*-M-M-M-|M*-M-M-M-M- + --spec: "[megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec]" + --log-params-norm: true + --log-num-zeros-in-grad: true + --log-validation-ppl-to-tensorboard: true + --log-timers-to-tensorboard: true + --tensorboard-dir: ${TENSORBOARD_PATH} + --micro-batch-size: 4 + --global-batch-size: 32 + --seq-length: 1024 + --max-position-embeddings: 1024 + --train-iters: 50 + --timing-log-level: 0 + --lr-decay-iters: 320000 + --save: ${CHECKPOINT_SAVE_PATH} + --load: ${CHECKPOINT_LOAD_PATH} + --data-path: ${DATA_PATH}/text/the_pile/shard00/my-gpt3_00_text_document + --vocab-file: ${DATA_PATH}/text/the_pile/shard00/bpe/vocab.json + --merge-file: ${DATA_PATH}/text/the_pile/shard00/bpe/merges.txt + --split: 949,50,1 + --distributed-backend: nccl + --lr: 0.00015 + --lr-decay-style: cosine + --min-lr: 1.0e-5 + --weight-decay: 1e-2 + --clip-grad: 1.0 + --lr-warmup-fraction: .01 + --log-interval: 1 + --save-interval: 10000 + --eval-interval: 1000 + --eval-iters: 10 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 1 + --pipeline-model-parallel-size: 2 + --use-distributed-optimizer: true + --overlap-grad-reduce: true + --overlap-param-gather: true + --check-weight-hash-across-dp-replicas-interval: 10 + --ckpt-fully-parallel-load: true + --no-gradient-accumulation-fusion: true + --attention-softmax-in-fp32: true + --use-mcore-models: true + --ckpt-format: torch_dist + --dist-ckpt-optim-fully-reshardable: true + --data-cache-path: ${DATA_CACHE_PATH} + --bf16: true + --attention-backend: unfused + --log-memory-to-tensorboard: true +TEST_TYPE: regular diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/golden_values_dev_dgx_a100.json b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/golden_values_dev_dgx_a100.json index db5414bfb90..fe4e1b47237 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/golden_values_dev_dgx_a100.json +++ b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/golden_values_dev_dgx_a100.json @@ -1 +1,287 @@ -{"lm loss": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 10.81478, "5": 10.8517, "10": 10.78749, "15": 10.79506, "20": 10.69119, "25": 10.52293, "30": 10.34604, "35": 10.26168, "40": 10.07199, "45": 9.8098, "50": 9.88336}}, "num-zeros": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 1549.0, "5": 1915.0, "10": 1391.0, "15": 1773.0, "20": 1615.0, "25": 1748.0, "30": 1877.0, "35": 1915.0, "40": 2111.0, "45": 2009.0, "50": 2347.0}}, "mem-allocated-bytes": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 522846720.0, "5": 522846720.0, "10": 522846720.0, "15": 522846720.0, "20": 522846720.0, "25": 522846720.0, "30": 522846720.0, "35": 522846720.0, "40": 522846720.0, "45": 522846720.0, "50": 522846720.0}}, "mem-max-allocated-bytes": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 3768846848.0, "5": 3912608256.0, "10": 3912608256.0, "15": 3912608256.0, "20": 3912608256.0, "25": 3912608256.0, "30": 3912608256.0, "35": 3912608256.0, "40": 3912608256.0, "45": 3912608256.0, "50": 3912608256.0}}, "iteration-time": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 14.36782, "5": 0.18832, "10": 0.16735, "15": 0.16595, "20": 0.16466, "25": 0.16564, "30": 0.16594, "35": 0.16362, "40": 0.16524, "45": 0.16382, "50": 0.16329}}} \ No newline at end of file +{ + "lm loss": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 11.00877, + "2": 11.01519, + "3": 11.00482, + "4": 10.99234, + "5": 11.01519, + "6": 11.03124, + "7": 11.00366, + "8": 11.00197, + "9": 11.0015, + "10": 10.98399, + "11": 10.99266, + "12": 10.9843, + "13": 10.98228, + "14": 10.99069, + "15": 10.87369, + "16": 10.86468, + "17": 10.83441, + "18": 10.84799, + "19": 10.82979, + "20": 10.65983, + "21": 10.60262, + "22": 10.37409, + "23": 10.61651, + "24": 10.36412, + "25": 10.25993, + "26": 10.37776, + "27": 10.38284, + "28": 10.35243, + "29": 10.36265, + "30": 9.90501, + "31": 9.48621, + "32": 10.08722, + "33": 10.07604, + "34": 9.64526, + "35": 9.69425, + "36": 9.57868, + "37": 9.80085, + "38": 9.52328, + "39": 9.92115, + "40": 9.33512, + "41": 9.49131, + "42": 9.56855, + "43": 9.03905, + "44": 9.15098, + "45": 8.99463, + "46": 9.06041, + "47": 9.48005, + "48": 9.03809, + "49": 8.58598, + "50": 9.11529 + } + }, + "num-zeros": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 21267.0, + "2": 21884.0, + "3": 21326.0, + "4": 20950.0, + "5": 23623.0, + "6": 23932.0, + "7": 23495.0, + "8": 21753.0, + "9": 22899.0, + "10": 19373.0, + "11": 24898.0, + "12": 23215.0, + "13": 24409.0, + "14": 24364.0, + "15": 23531.0, + "16": 23774.0, + "17": 22229.0, + "18": 22401.0, + "19": 23408.0, + "20": 21373.0, + "21": 22326.0, + "22": 19058.0, + "23": 24204.0, + "24": 19277.0, + "25": 19016.0, + "26": 20631.0, + "27": 21847.0, + "28": 23190.0, + "29": 22742.0, + "30": 19683.0, + "31": 16624.0, + "32": 21448.0, + "33": 22649.0, + "34": 20897.0, + "35": 21541.0, + "36": 20787.0, + "37": 22503.0, + "38": 22392.0, + "39": 22121.0, + "40": 23558.0, + "41": 23430.0, + "42": 23131.0, + "43": 22389.0, + "44": 22413.0, + "45": 22360.0, + "46": 23710.0, + "47": 25110.0, + "48": 25559.0, + "49": 25440.0, + "50": 28269.0 + } + }, + "mem-allocated-bytes": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 1784293376.0, + "2": 1784293376.0, + "3": 1784293376.0, + "4": 1784293376.0, + "5": 1784293376.0, + "6": 1784293376.0, + "7": 1784293376.0, + "8": 1784293376.0, + "9": 1784293376.0, + "10": 1784293376.0, + "11": 1784293376.0, + "12": 1784293376.0, + "13": 1784293376.0, + "14": 1784293376.0, + "15": 1784293376.0, + "16": 1784293376.0, + "17": 1784293376.0, + "18": 1784293376.0, + "19": 1784293376.0, + "20": 1784293376.0, + "21": 1784293376.0, + "22": 1784293376.0, + "23": 1784293376.0, + "24": 1784293376.0, + "25": 1784293376.0, + "26": 1784293376.0, + "27": 1784293376.0, + "28": 1784293376.0, + "29": 1784293376.0, + "30": 1784293376.0, + "31": 1784293376.0, + "32": 1784293376.0, + "33": 1784293376.0, + "34": 1784293376.0, + "35": 1784293376.0, + "36": 1784293376.0, + "37": 1784293376.0, + "38": 1784293376.0, + "39": 1784293376.0, + "40": 1784293376.0, + "41": 1784293376.0, + "42": 1784293376.0, + "43": 1784293376.0, + "44": 1784293376.0, + "45": 1784293376.0, + "46": 1784293376.0, + "47": 1784293376.0, + "48": 1784293376.0, + "49": 1784293376.0, + "50": 1784293376.0 + } + }, + "mem-max-allocated-bytes": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 3927581184.0, + "2": 4483710976.0, + "3": 4483712000.0, + "4": 4483712000.0, + "5": 4483712000.0, + "6": 4483712000.0, + "7": 4483712000.0, + "8": 4483712000.0, + "9": 4483712000.0, + "10": 4483712000.0, + "11": 4483712000.0, + "12": 4483712000.0, + "13": 4483712000.0, + "14": 4483712000.0, + "15": 4483712000.0, + "16": 4483712000.0, + "17": 4483712000.0, + "18": 4483712000.0, + "19": 4483712000.0, + "20": 4483712000.0, + "21": 4483712000.0, + "22": 4483712000.0, + "23": 4483712000.0, + "24": 4483712000.0, + "25": 4483712000.0, + "26": 4483712000.0, + "27": 4483712000.0, + "28": 4483712000.0, + "29": 4483712000.0, + "30": 4483712000.0, + "31": 4483712000.0, + "32": 4483712000.0, + "33": 4483712000.0, + "34": 4483712000.0, + "35": 4483712000.0, + "36": 4483712000.0, + "37": 4483712000.0, + "38": 4483712000.0, + "39": 4483712000.0, + "40": 4483712000.0, + "41": 4483712000.0, + "42": 4483712000.0, + "43": 4483712000.0, + "44": 4483712000.0, + "45": 4483712000.0, + "46": 4483712000.0, + "47": 4483712000.0, + "48": 4483712000.0, + "49": 4483712000.0, + "50": 4483712000.0 + } + }, + "iteration-time": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": "nan", + "2": 132.6674, + "3": 0.27848, + "4": 0.26954, + "5": 0.26637, + "6": 0.26739, + "7": 0.26532, + "8": 0.25707, + "9": 0.25629, + "10": 0.25181, + "11": 0.67139, + "12": 0.24953, + "13": 0.25118, + "14": 0.24964, + "15": 0.24974, + "16": 0.25107, + "17": 0.25047, + "18": 0.24929, + "19": 0.24953, + "20": 0.24912, + "21": 0.65954, + "22": 0.24963, + "23": 0.24904, + "24": 0.24833, + "25": 0.24817, + "26": 0.24791, + "27": 0.2476, + "28": 0.25156, + "29": 0.24992, + "30": 0.24744, + "31": 0.66249, + "32": 0.24825, + "33": 0.24942, + "34": 0.24992, + "35": 0.24883, + "36": 0.24938, + "37": 0.24961, + "38": 0.25008, + "39": 0.24859, + "40": 0.24809, + "41": 0.65959, + "42": 0.24801, + "43": 0.24803, + "44": 0.24795, + "45": 0.24849, + "46": 0.25118, + "47": 0.24896, + "48": 0.24909, + "49": 0.24926, + "50": 0.24903 + } + } +} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/golden_values_dev_dgx_h100.json new file mode 100644 index 00000000000..fe4e1b47237 --- /dev/null +++ b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/golden_values_dev_dgx_h100.json @@ -0,0 +1,287 @@ +{ + "lm loss": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 11.00877, + "2": 11.01519, + "3": 11.00482, + "4": 10.99234, + "5": 11.01519, + "6": 11.03124, + "7": 11.00366, + "8": 11.00197, + "9": 11.0015, + "10": 10.98399, + "11": 10.99266, + "12": 10.9843, + "13": 10.98228, + "14": 10.99069, + "15": 10.87369, + "16": 10.86468, + "17": 10.83441, + "18": 10.84799, + "19": 10.82979, + "20": 10.65983, + "21": 10.60262, + "22": 10.37409, + "23": 10.61651, + "24": 10.36412, + "25": 10.25993, + "26": 10.37776, + "27": 10.38284, + "28": 10.35243, + "29": 10.36265, + "30": 9.90501, + "31": 9.48621, + "32": 10.08722, + "33": 10.07604, + "34": 9.64526, + "35": 9.69425, + "36": 9.57868, + "37": 9.80085, + "38": 9.52328, + "39": 9.92115, + "40": 9.33512, + "41": 9.49131, + "42": 9.56855, + "43": 9.03905, + "44": 9.15098, + "45": 8.99463, + "46": 9.06041, + "47": 9.48005, + "48": 9.03809, + "49": 8.58598, + "50": 9.11529 + } + }, + "num-zeros": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 21267.0, + "2": 21884.0, + "3": 21326.0, + "4": 20950.0, + "5": 23623.0, + "6": 23932.0, + "7": 23495.0, + "8": 21753.0, + "9": 22899.0, + "10": 19373.0, + "11": 24898.0, + "12": 23215.0, + "13": 24409.0, + "14": 24364.0, + "15": 23531.0, + "16": 23774.0, + "17": 22229.0, + "18": 22401.0, + "19": 23408.0, + "20": 21373.0, + "21": 22326.0, + "22": 19058.0, + "23": 24204.0, + "24": 19277.0, + "25": 19016.0, + "26": 20631.0, + "27": 21847.0, + "28": 23190.0, + "29": 22742.0, + "30": 19683.0, + "31": 16624.0, + "32": 21448.0, + "33": 22649.0, + "34": 20897.0, + "35": 21541.0, + "36": 20787.0, + "37": 22503.0, + "38": 22392.0, + "39": 22121.0, + "40": 23558.0, + "41": 23430.0, + "42": 23131.0, + "43": 22389.0, + "44": 22413.0, + "45": 22360.0, + "46": 23710.0, + "47": 25110.0, + "48": 25559.0, + "49": 25440.0, + "50": 28269.0 + } + }, + "mem-allocated-bytes": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 1784293376.0, + "2": 1784293376.0, + "3": 1784293376.0, + "4": 1784293376.0, + "5": 1784293376.0, + "6": 1784293376.0, + "7": 1784293376.0, + "8": 1784293376.0, + "9": 1784293376.0, + "10": 1784293376.0, + "11": 1784293376.0, + "12": 1784293376.0, + "13": 1784293376.0, + "14": 1784293376.0, + "15": 1784293376.0, + "16": 1784293376.0, + "17": 1784293376.0, + "18": 1784293376.0, + "19": 1784293376.0, + "20": 1784293376.0, + "21": 1784293376.0, + "22": 1784293376.0, + "23": 1784293376.0, + "24": 1784293376.0, + "25": 1784293376.0, + "26": 1784293376.0, + "27": 1784293376.0, + "28": 1784293376.0, + "29": 1784293376.0, + "30": 1784293376.0, + "31": 1784293376.0, + "32": 1784293376.0, + "33": 1784293376.0, + "34": 1784293376.0, + "35": 1784293376.0, + "36": 1784293376.0, + "37": 1784293376.0, + "38": 1784293376.0, + "39": 1784293376.0, + "40": 1784293376.0, + "41": 1784293376.0, + "42": 1784293376.0, + "43": 1784293376.0, + "44": 1784293376.0, + "45": 1784293376.0, + "46": 1784293376.0, + "47": 1784293376.0, + "48": 1784293376.0, + "49": 1784293376.0, + "50": 1784293376.0 + } + }, + "mem-max-allocated-bytes": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 3927581184.0, + "2": 4483710976.0, + "3": 4483712000.0, + "4": 4483712000.0, + "5": 4483712000.0, + "6": 4483712000.0, + "7": 4483712000.0, + "8": 4483712000.0, + "9": 4483712000.0, + "10": 4483712000.0, + "11": 4483712000.0, + "12": 4483712000.0, + "13": 4483712000.0, + "14": 4483712000.0, + "15": 4483712000.0, + "16": 4483712000.0, + "17": 4483712000.0, + "18": 4483712000.0, + "19": 4483712000.0, + "20": 4483712000.0, + "21": 4483712000.0, + "22": 4483712000.0, + "23": 4483712000.0, + "24": 4483712000.0, + "25": 4483712000.0, + "26": 4483712000.0, + "27": 4483712000.0, + "28": 4483712000.0, + "29": 4483712000.0, + "30": 4483712000.0, + "31": 4483712000.0, + "32": 4483712000.0, + "33": 4483712000.0, + "34": 4483712000.0, + "35": 4483712000.0, + "36": 4483712000.0, + "37": 4483712000.0, + "38": 4483712000.0, + "39": 4483712000.0, + "40": 4483712000.0, + "41": 4483712000.0, + "42": 4483712000.0, + "43": 4483712000.0, + "44": 4483712000.0, + "45": 4483712000.0, + "46": 4483712000.0, + "47": 4483712000.0, + "48": 4483712000.0, + "49": 4483712000.0, + "50": 4483712000.0 + } + }, + "iteration-time": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": "nan", + "2": 132.6674, + "3": 0.27848, + "4": 0.26954, + "5": 0.26637, + "6": 0.26739, + "7": 0.26532, + "8": 0.25707, + "9": 0.25629, + "10": 0.25181, + "11": 0.67139, + "12": 0.24953, + "13": 0.25118, + "14": 0.24964, + "15": 0.24974, + "16": 0.25107, + "17": 0.25047, + "18": 0.24929, + "19": 0.24953, + "20": 0.24912, + "21": 0.65954, + "22": 0.24963, + "23": 0.24904, + "24": 0.24833, + "25": 0.24817, + "26": 0.24791, + "27": 0.2476, + "28": 0.25156, + "29": 0.24992, + "30": 0.24744, + "31": 0.66249, + "32": 0.24825, + "33": 0.24942, + "34": 0.24992, + "35": 0.24883, + "36": 0.24938, + "37": 0.24961, + "38": 0.25008, + "39": 0.24859, + "40": 0.24809, + "41": 0.65959, + "42": 0.24801, + "43": 0.24803, + "44": 0.24795, + "45": 0.24849, + "46": 0.25118, + "47": 0.24896, + "48": 0.24909, + "49": 0.24926, + "50": 0.24903 + } + } +} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/golden_values_lts_dgx_a100.json b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/golden_values_lts_dgx_a100.json index dc393d0dffc..fe4e1b47237 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/golden_values_lts_dgx_a100.json +++ b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/golden_values_lts_dgx_a100.json @@ -1 +1,287 @@ -{"lm loss": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 10.82005, "5": 10.85284, "10": 10.78455, "15": 10.7923, "20": 10.69213, "25": 10.5241, "30": 10.34556, "35": 10.26241, "40": 10.07237, "45": 9.811, "50": 9.88419}}, "num-zeros": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 1559.0, "5": 1840.0, "10": 1380.0, "15": 1850.0, "20": 1699.0, "25": 1614.0, "30": 1905.0, "35": 1933.0, "40": 2169.0, "45": 2101.0, "50": 2421.0}}, "mem-allocated-bytes": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 523004928.0, "5": 523004928.0, "10": 523004928.0, "15": 523004928.0, "20": 523004928.0, "25": 523004928.0, "30": 523004928.0, "35": 523004928.0, "40": 523004928.0, "45": 523004928.0, "50": 523004928.0}}, "mem-max-allocated-bytes": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 3768873984.0, "5": 3912766464.0, "10": 3912766464.0, "15": 3912766464.0, "20": 3912766464.0, "25": 3912766464.0, "30": 3912766464.0, "35": 3912766464.0, "40": 3912766464.0, "45": 3912766464.0, "50": 3912766464.0}}, "iteration-time": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 18.88705, "5": 0.16956, "10": 0.17448, "15": 0.16853, "20": 0.1715, "25": 0.17071, "30": 0.17343, "35": 0.17213, "40": 0.1719, "45": 0.17357, "50": 0.17228}}} \ No newline at end of file +{ + "lm loss": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 11.00877, + "2": 11.01519, + "3": 11.00482, + "4": 10.99234, + "5": 11.01519, + "6": 11.03124, + "7": 11.00366, + "8": 11.00197, + "9": 11.0015, + "10": 10.98399, + "11": 10.99266, + "12": 10.9843, + "13": 10.98228, + "14": 10.99069, + "15": 10.87369, + "16": 10.86468, + "17": 10.83441, + "18": 10.84799, + "19": 10.82979, + "20": 10.65983, + "21": 10.60262, + "22": 10.37409, + "23": 10.61651, + "24": 10.36412, + "25": 10.25993, + "26": 10.37776, + "27": 10.38284, + "28": 10.35243, + "29": 10.36265, + "30": 9.90501, + "31": 9.48621, + "32": 10.08722, + "33": 10.07604, + "34": 9.64526, + "35": 9.69425, + "36": 9.57868, + "37": 9.80085, + "38": 9.52328, + "39": 9.92115, + "40": 9.33512, + "41": 9.49131, + "42": 9.56855, + "43": 9.03905, + "44": 9.15098, + "45": 8.99463, + "46": 9.06041, + "47": 9.48005, + "48": 9.03809, + "49": 8.58598, + "50": 9.11529 + } + }, + "num-zeros": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 21267.0, + "2": 21884.0, + "3": 21326.0, + "4": 20950.0, + "5": 23623.0, + "6": 23932.0, + "7": 23495.0, + "8": 21753.0, + "9": 22899.0, + "10": 19373.0, + "11": 24898.0, + "12": 23215.0, + "13": 24409.0, + "14": 24364.0, + "15": 23531.0, + "16": 23774.0, + "17": 22229.0, + "18": 22401.0, + "19": 23408.0, + "20": 21373.0, + "21": 22326.0, + "22": 19058.0, + "23": 24204.0, + "24": 19277.0, + "25": 19016.0, + "26": 20631.0, + "27": 21847.0, + "28": 23190.0, + "29": 22742.0, + "30": 19683.0, + "31": 16624.0, + "32": 21448.0, + "33": 22649.0, + "34": 20897.0, + "35": 21541.0, + "36": 20787.0, + "37": 22503.0, + "38": 22392.0, + "39": 22121.0, + "40": 23558.0, + "41": 23430.0, + "42": 23131.0, + "43": 22389.0, + "44": 22413.0, + "45": 22360.0, + "46": 23710.0, + "47": 25110.0, + "48": 25559.0, + "49": 25440.0, + "50": 28269.0 + } + }, + "mem-allocated-bytes": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 1784293376.0, + "2": 1784293376.0, + "3": 1784293376.0, + "4": 1784293376.0, + "5": 1784293376.0, + "6": 1784293376.0, + "7": 1784293376.0, + "8": 1784293376.0, + "9": 1784293376.0, + "10": 1784293376.0, + "11": 1784293376.0, + "12": 1784293376.0, + "13": 1784293376.0, + "14": 1784293376.0, + "15": 1784293376.0, + "16": 1784293376.0, + "17": 1784293376.0, + "18": 1784293376.0, + "19": 1784293376.0, + "20": 1784293376.0, + "21": 1784293376.0, + "22": 1784293376.0, + "23": 1784293376.0, + "24": 1784293376.0, + "25": 1784293376.0, + "26": 1784293376.0, + "27": 1784293376.0, + "28": 1784293376.0, + "29": 1784293376.0, + "30": 1784293376.0, + "31": 1784293376.0, + "32": 1784293376.0, + "33": 1784293376.0, + "34": 1784293376.0, + "35": 1784293376.0, + "36": 1784293376.0, + "37": 1784293376.0, + "38": 1784293376.0, + "39": 1784293376.0, + "40": 1784293376.0, + "41": 1784293376.0, + "42": 1784293376.0, + "43": 1784293376.0, + "44": 1784293376.0, + "45": 1784293376.0, + "46": 1784293376.0, + "47": 1784293376.0, + "48": 1784293376.0, + "49": 1784293376.0, + "50": 1784293376.0 + } + }, + "mem-max-allocated-bytes": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 3927581184.0, + "2": 4483710976.0, + "3": 4483712000.0, + "4": 4483712000.0, + "5": 4483712000.0, + "6": 4483712000.0, + "7": 4483712000.0, + "8": 4483712000.0, + "9": 4483712000.0, + "10": 4483712000.0, + "11": 4483712000.0, + "12": 4483712000.0, + "13": 4483712000.0, + "14": 4483712000.0, + "15": 4483712000.0, + "16": 4483712000.0, + "17": 4483712000.0, + "18": 4483712000.0, + "19": 4483712000.0, + "20": 4483712000.0, + "21": 4483712000.0, + "22": 4483712000.0, + "23": 4483712000.0, + "24": 4483712000.0, + "25": 4483712000.0, + "26": 4483712000.0, + "27": 4483712000.0, + "28": 4483712000.0, + "29": 4483712000.0, + "30": 4483712000.0, + "31": 4483712000.0, + "32": 4483712000.0, + "33": 4483712000.0, + "34": 4483712000.0, + "35": 4483712000.0, + "36": 4483712000.0, + "37": 4483712000.0, + "38": 4483712000.0, + "39": 4483712000.0, + "40": 4483712000.0, + "41": 4483712000.0, + "42": 4483712000.0, + "43": 4483712000.0, + "44": 4483712000.0, + "45": 4483712000.0, + "46": 4483712000.0, + "47": 4483712000.0, + "48": 4483712000.0, + "49": 4483712000.0, + "50": 4483712000.0 + } + }, + "iteration-time": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": "nan", + "2": 132.6674, + "3": 0.27848, + "4": 0.26954, + "5": 0.26637, + "6": 0.26739, + "7": 0.26532, + "8": 0.25707, + "9": 0.25629, + "10": 0.25181, + "11": 0.67139, + "12": 0.24953, + "13": 0.25118, + "14": 0.24964, + "15": 0.24974, + "16": 0.25107, + "17": 0.25047, + "18": 0.24929, + "19": 0.24953, + "20": 0.24912, + "21": 0.65954, + "22": 0.24963, + "23": 0.24904, + "24": 0.24833, + "25": 0.24817, + "26": 0.24791, + "27": 0.2476, + "28": 0.25156, + "29": 0.24992, + "30": 0.24744, + "31": 0.66249, + "32": 0.24825, + "33": 0.24942, + "34": 0.24992, + "35": 0.24883, + "36": 0.24938, + "37": 0.24961, + "38": 0.25008, + "39": 0.24859, + "40": 0.24809, + "41": 0.65959, + "42": 0.24801, + "43": 0.24803, + "44": 0.24795, + "45": 0.24849, + "46": 0.25118, + "47": 0.24896, + "48": 0.24909, + "49": 0.24926, + "50": 0.24903 + } + } +} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/model_config.yaml index 0983337becc..8cecc7de2ed 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/model_config.yaml @@ -4,14 +4,12 @@ ENV_VARS: NCCL_ALGO: Ring CUBLAS_WORKSPACE_CONFIG: :4096:8 MODEL_ARGS: - --num-layers: 44 --hidden-size: 1024 --num-attention-heads: 16 --group-query-attention: true --num-query-groups: 8 - --hybrid-override-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M- + --hybrid-layer-pattern: M-M-M-M*-M-|M-M-M*-M-M-|M-M*-M-M-M-|M*-M-M-M-M- --spec: "[megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec]" - --is-hybrid-model: true --log-params-norm: true --log-num-zeros-in-grad: true --log-validation-ppl-to-tensorboard: true diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp2_pp1_cp1_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp2_pp1_cp1_dgx_a100_1N8G/model_config.yaml index 7f7aac5d78b..79b6fd506bd 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp2_pp1_cp1_dgx_a100_1N8G/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp2_pp1_cp1_dgx_a100_1N8G/model_config.yaml @@ -4,14 +4,12 @@ ENV_VARS: NCCL_ALGO: Ring CUBLAS_WORKSPACE_CONFIG: :4096:8 MODEL_ARGS: - --num-layers: 44 --hidden-size: 1024 --num-attention-heads: 16 --group-query-attention: true --num-query-groups: 8 - --hybrid-override-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M- + --hybrid-layer-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M- --spec: "[megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec]" - --is-hybrid-model: true --log-params-norm: true --log-num-zeros-in-grad: true --log-validation-ppl-to-tensorboard: true diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp2_pp1_cp4_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp2_pp1_cp4_dgx_a100_1N8G/model_config.yaml index 93418f580fc..7e16a27960f 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp2_pp1_cp4_dgx_a100_1N8G/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp2_pp1_cp4_dgx_a100_1N8G/model_config.yaml @@ -4,14 +4,12 @@ ENV_VARS: NCCL_ALGO: Ring CUBLAS_WORKSPACE_CONFIG: :4096:8 MODEL_ARGS: - --num-layers: 44 --hidden-size: 1024 --num-attention-heads: 16 --group-query-attention: true --num-query-groups: 8 - --hybrid-override-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M- + --hybrid-layer-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M- --spec: "[megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec]" - --is-hybrid-model: true --log-params-norm: true --log-num-zeros-in-grad: true --log-validation-ppl-to-tensorboard: true diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_cudagraphs/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_cudagraphs/model_config.yaml index 5bc40afede4..26708b32a60 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_cudagraphs/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_cudagraphs/model_config.yaml @@ -22,19 +22,17 @@ MODEL_ARGS: --pipeline-model-parallel-size: 1 --expert-model-parallel-size: 1 --use-mcore-models: true - --is-hybrid-model: true --model-provider: mamba --init-method-std: 0.0198 --untie-embeddings-and-output-weights: true --disable-bias-linear: true --init-method-std: 0.014 --position-embedding-type: none - --num-layers: 50 --hidden-size: 2048 --ffn-hidden-size: 11264 --num-attention-heads: 16 --kv-channels: 128 - --hybrid-override-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- + --hybrid-layer-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- --spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec --normalization: RMSNorm --swiglu: true diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_logitsmatch/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_logitsmatch/model_config.yaml index b5c3c409605..3964bcb8ecb 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_logitsmatch/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_logitsmatch/model_config.yaml @@ -22,19 +22,17 @@ MODEL_ARGS: --pipeline-model-parallel-size: 1 --expert-model-parallel-size: 1 --use-mcore-models: true - --is-hybrid-model: true --model-provider: mamba --init-method-std: 0.0198 --untie-embeddings-and-output-weights: true --disable-bias-linear: true --init-method-std: 0.014 --position-embedding-type: none - --num-layers: 50 --hidden-size: 2048 --ffn-hidden-size: 11264 --num-attention-heads: 16 --kv-channels: 128 - --hybrid-override-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- + --hybrid-layer-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- --spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec --normalization: RMSNorm --swiglu: true diff --git a/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_gb_200_release/model_config.yaml b/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_gb_200_release/model_config.yaml index a77d456506c..fb159dd8839 100644 --- a/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_gb_200_release/model_config.yaml +++ b/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_gb_200_release/model_config.yaml @@ -21,7 +21,7 @@ MODEL_ARGS: --distributed-timeout-minutes: 60 --tensor-model-parallel-size: 2 --pipeline-model-parallel-size: 4 - --pipeline-model-parallel-layout: Et*2\\|\\(tt\\|\\)*5t\\|tmL # Et*2|(tt|)*5t|tmL + --pipeline-model-parallel-layout: Et*2|(tt|)*5t|tmL --expert-model-parallel-size: 16 --context-parallel-size: 1 --expert-tensor-parallel-size: 1 diff --git a/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_gb_200_release_sm/model_config.yaml b/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_gb_200_release_sm/model_config.yaml index beb3633b510..3356b18ef77 100644 --- a/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_gb_200_release_sm/model_config.yaml +++ b/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_gb_200_release_sm/model_config.yaml @@ -21,7 +21,7 @@ MODEL_ARGS: --distributed-timeout-minutes: 60 --tensor-model-parallel-size: 2 --pipeline-model-parallel-size: 4 - --pipeline-model-parallel-layout: Et*2\\|\\(tt\\|\\)*5t\\|tmL # Et*2|(tt|)*5t|tmL + --pipeline-model-parallel-layout: Et*2|(tt|)*5t|tmL --expert-model-parallel-size: 16 --context-parallel-size: 1 --expert-tensor-parallel-size: 1 diff --git a/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_release/model_config.yaml b/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_release/model_config.yaml index ab618ab915c..e504bcb1320 100644 --- a/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_release/model_config.yaml +++ b/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_release/model_config.yaml @@ -19,7 +19,7 @@ MODEL_ARGS: --distributed-timeout-minutes: 60 --tensor-model-parallel-size: 2 --pipeline-model-parallel-size: 4 - --pipeline-model-parallel-layout: Et*2\\|\\(tt\\|\\)*5t\\|tmL # Et*2|(tt|)*5t|tmL + --pipeline-model-parallel-layout: Et*2|(tt|)*5t|tmL --expert-model-parallel-size: 16 --context-parallel-size: 1 --expert-tensor-parallel-size: 1 diff --git a/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_release_sm/model_config.yaml b/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_release_sm/model_config.yaml index ab1df7f1d1e..49cca71a596 100644 --- a/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_release_sm/model_config.yaml +++ b/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_release_sm/model_config.yaml @@ -19,7 +19,7 @@ MODEL_ARGS: --distributed-timeout-minutes: 60 --tensor-model-parallel-size: 2 --pipeline-model-parallel-size: 4 - --pipeline-model-parallel-layout: Et*2\\|\\(tt\\|\\)*5t\\|tmL # Et*2|(tt|)*5t|tmL + --pipeline-model-parallel-layout: Et*2|(tt|)*5t|tmL --expert-model-parallel-size: 16 --context-parallel-size: 1 --expert-tensor-parallel-size: 1 diff --git a/tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp2_pp2_ep4_etp1_selective_recompute_experimental/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp2_pp2_ep4_etp1_selective_recompute_experimental/model_config.yaml index efb1fedf93c..e7971347f02 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp2_pp2_ep4_etp1_selective_recompute_experimental/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp2_pp2_ep4_etp1_selective_recompute_experimental/model_config.yaml @@ -46,7 +46,7 @@ MODEL_ARGS: # Add network size args --num-layers: 16 --moe-layer-freq: ([0]*3+[1]*13) - --pipeline-model-parallel-layout: Et*3\\|\\(tt\\|\\)*6tmL # Et*3|(tt|)*6tmL + --pipeline-model-parallel-layout: Et*3|(tt|)*6tmL --hidden-size: 1024 --ffn-hidden-size: 4096 --num-attention-heads: 32 diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/model_config.yaml index a37dd0dc658..c04d55564a3 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/model_config.yaml @@ -51,7 +51,7 @@ MODEL_ARGS: # Add network size args --num-layers: 15 --moe-layer-freq: ([0]*3+[1]*12) - --pipeline-model-parallel-layout: Et*3\\|\\(tt\\|\\)*6mL # Et*3|(tt|)*6mL + --pipeline-model-parallel-layout: Et*3|(tt|)*6mL --hidden-size: 1024 --ffn-hidden-size: 4096 --num-attention-heads: 32 diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/model_config.yaml index da78378ddae..dbfb29ea48c 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/model_config.yaml @@ -51,7 +51,7 @@ MODEL_ARGS: # Add network size args --num-layers: 15 --moe-layer-freq: ([0]*3+[1]*12) - --pipeline-model-parallel-layout: Et*3\\|\\(tt\\|\\)*6L # Et*3|(tt|)*6L + --pipeline-model-parallel-layout: Et*3|(tt|)*6L --hidden-size: 1024 --ffn-hidden-size: 4096 --num-attention-heads: 32 diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/model_config.yaml index f0d1cc0afd3..fd0d79e0986 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/model_config.yaml @@ -33,7 +33,7 @@ MODEL_ARGS: --pipeline-model-parallel-size: 2 --expert-model-parallel-size: 2 --expert-tensor-parallel-size: 2 - --pipeline-model-parallel-layout: Et\\|\\(tt\\|\\)*6mL # Et|(tt|)*6mL + --pipeline-model-parallel-layout: Et|(tt|)*6mL --sequence-parallel: true --num-experts: 8 --use-distributed-optimizer: true diff --git a/tests/test_utils/recipes/h100/mamba.yaml b/tests/test_utils/recipes/h100/mamba.yaml index 456a6cbccf7..703fb53160f 100644 --- a/tests/test_utils/recipes/h100/mamba.yaml +++ b/tests/test_utils/recipes/h100/mamba.yaml @@ -63,13 +63,21 @@ products: # - environment: [lts] # disabled until triton is bumped # scope: [nightly] - # PP functional testing deferred - # - test_case: [hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G] - # products: - # - environment: [dev] - # scope: [mr] - # - environment: [lts] # disabled until triton is bumped - # scope: [nightly] + - test_case: [hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G] + products: + - environment: [dev] + scope: [mr, mr-github] + platforms: [dgx_h100] + # - environment: [lts] # disabled until triton is bumped + # scope: [nightly] + + - test_case: [hybrid_mr_mcore_te_tp1_pp2_vpp2_cp1_dgx_a100_1N8G] + products: + - environment: [dev] + scope: [mr, mr-github] + platforms: [dgx_h100] + # - environment: [lts] # disabled until triton is bumped + # scope: [nightly] - test_case: [hybrid_mr_mcore_te_tp2_pp1_cp1_dgx_a100_1N8G] products: diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index 02be5c136fd..0a9a9e8a384 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -389,8 +389,9 @@ def _build_test_env(cls, test_config): vocab_size=test_config.vocab_size, max_sequence_length=test_config.max_sequence_length, parallel_output=True, - hybrid_attention_ratio=0.3, - hybrid_mlp_ratio=0.3, + hybrid_layer_pattern=( + "M*-" if pp_size == 1 else "M*-|M*-" + ), # 3 or 6 layers (2 PP stages) pre_process=parallel_state.is_pipeline_first_stage(), post_process=parallel_state.is_pipeline_last_stage(), ).cuda() diff --git a/tests/unit_tests/models/test_mamba_model.py b/tests/unit_tests/models/test_mamba_model.py index 29e3630d7bb..e11ac385188 100644 --- a/tests/unit_tests/models/test_mamba_model.py +++ b/tests/unit_tests/models/test_mamba_model.py @@ -42,8 +42,7 @@ def setup_method(self, method): mamba_stack_spec=mamba_stack_spec, vocab_size=100, max_sequence_length=4, - hybrid_attention_ratio=0.3, - hybrid_mlp_ratio=0.3, + hybrid_layer_pattern="M*-", # 1 Mamba, 1 attention, 1 MLP ) def teardown_method(self, method): @@ -111,8 +110,7 @@ def test_forward_packed_sequence(self): mamba_stack_spec=mamba_stack_spec, vocab_size=vocab_size, max_sequence_length=12, - hybrid_attention_ratio=0.3, - hybrid_mlp_ratio=0.3, + hybrid_layer_pattern="M*-", # 1 Mamba, 1 attention, 1 MLP ) sequence_length = model.max_sequence_length @@ -247,6 +245,9 @@ def test_with_custom_process_groups(self, tmp_path, tp_size, cp_size, pp_size): tp=tp_group, cp=cp_group, pp=pp_group, embd=embd_group ) + # Build pattern with '|' pipeline stage separators: 3 layers per PP stage + hybrid_layer_pattern = "|".join(["M*-"] * pp_size) + # Configure model with appropriate sizes for parallelism model_config = TransformerConfig( num_layers=3 * pp_size, # Scale layers with PP size @@ -264,8 +265,7 @@ def test_with_custom_process_groups(self, tmp_path, tp_size, cp_size, pp_size): mamba_stack_spec=mamba_stack_spec, vocab_size=128, max_sequence_length=4, - hybrid_attention_ratio=0.3, - hybrid_mlp_ratio=0.3, + hybrid_layer_pattern=hybrid_layer_pattern, pg_collection=pg_collection, ) @@ -319,8 +319,7 @@ def setup_method(self, method): mamba_stack_spec=mamba_stack_spec, vocab_size=128, max_sequence_length=DynamicInferenceContext.TOKEN_ROUNDER, - hybrid_attention_ratio=0.5, - hybrid_mlp_ratio=0.0, + hybrid_layer_pattern="M*", # 1 Mamba, 1 attention ) self.model = Float16Module(self.model.config, self.model) diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py index f933d811779..15bc4bec341 100644 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -390,7 +390,6 @@ def create_test_args(self): args = parse_args() # The following args would be set from the nano v3 checkpoint. - args.num_layers = 52 args.hidden_size = 2688 args.ffn_hidden_size = 1856 args.num_attention_heads = 32 @@ -413,10 +412,9 @@ def create_test_args(self): args.apply_query_key_layer_scaling = False args.attention_dropout = 0.0 args.hidden_dropout = 0.0 - args.hybrid_override_pattern = "MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME" + args.hybrid_layer_pattern = "MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME" + args.hybrid_override_pattern = None args.spec = ["megatron.core.models.mamba.mamba_layer_specs", "mamba_stack_spec"] - args.hybrid_attention_ratio = 0.0 - args.hybrid_mlp_ratio = 0.0 args.num_experts = 128 args.moe_layer_freq = 1 args.moe_ffn_hidden_size = 1856 @@ -431,7 +429,6 @@ def create_test_args(self): args.mamba_head_dim = 64 args.mamba_num_groups = 8 args.mamba_num_heads = 64 - args.is_hybrid_model = True args.tokenizer_type = "TikTokenizer" args.tiktoken_pattern = "v2" args.tokenizer_model = "/mnt/artifacts/model/nemotron6/tokenizers/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json" @@ -495,9 +492,7 @@ def setup_method(self, method): mamba_stack_spec=mamba_stack_spec, vocab_size=args.vocab_size, max_sequence_length=args.seq_length, - hybrid_attention_ratio=args.hybrid_attention_ratio, - hybrid_mlp_ratio=args.hybrid_mlp_ratio, - hybrid_override_pattern=args.hybrid_override_pattern, + hybrid_layer_pattern=args.hybrid_layer_pattern, position_embedding_type=args.position_embedding_type, rotary_base=args.rotary_base, rotary_percent=args.rotary_percent, @@ -515,11 +510,9 @@ def test_constructor(self): assert self.model.pre_process is True, "pre_process should be True" assert self.model.post_process is True, "post_process should be True" - assert self.model.hybrid_attention_ratio == 0.0, "hybrid_attention_ratio should be 0.0" - assert self.model.hybrid_mlp_ratio == 0.0, "hybrid_mlp_ratio should be 0.0" assert ( - self.model.hybrid_override_pattern == args.hybrid_override_pattern - ), f"hybrid_override_pattern should be {args.hybrid_override_pattern}" + self.model.hybrid_layer_pattern == args.hybrid_layer_pattern + ), f"hybrid_layer_pattern should be {args.hybrid_layer_pattern}" num_weights = sum([p.numel() for p in self.model.parameters()]) assert num_weights == 8449294624, f"Expected 8449294624 parameters, got {num_weights}" diff --git a/tests/unit_tests/post_training/test_modelopt_module_spec.py b/tests/unit_tests/post_training/test_modelopt_module_spec.py index 3f6491f835d..585be52f944 100644 --- a/tests/unit_tests/post_training/test_modelopt_module_spec.py +++ b/tests/unit_tests/post_training/test_modelopt_module_spec.py @@ -212,7 +212,7 @@ def setup_method(self, method): mamba_stack_spec=mamba_stack_spec, vocab_size=100, max_sequence_length=4, - hybrid_override_pattern="M*-", + hybrid_layer_pattern="M*-", ) # A Hybrid MambaModel using ModelOpt spec (local + TENorm). @@ -221,7 +221,7 @@ def setup_method(self, method): mamba_stack_spec=get_mamba_stack_modelopt_spec(remap_te_layernorm=True), vocab_size=100, max_sequence_length=4, - hybrid_override_pattern="M*-", + hybrid_layer_pattern="M*-", ) diff --git a/tests/unit_tests/ssm/test_mamba_block.py b/tests/unit_tests/ssm/test_mamba_block.py index 909ee47e836..c65623e08e0 100644 --- a/tests/unit_tests/ssm/test_mamba_block.py +++ b/tests/unit_tests/ssm/test_mamba_block.py @@ -6,7 +6,7 @@ from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.ssm.mamba_block import MambaStack -from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols +from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols, validate_segment_layers from megatron.core.ssm.mamba_layer import MambaLayer from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer import TransformerConfig @@ -26,12 +26,13 @@ def setup_method(self, method): def get_pg_collection(self): return ProcessGroupCollection.use_mpu_process_groups(required_pgs=['tp', 'pp', 'cp']) - def get_mamba_block(self, hybrid_override_pattern): + def get_mamba_block(self, layer_pattern): + layer_type_list = validate_segment_layers(layer_pattern) transformer_config = TransformerConfig( hidden_size=256, # The Mamba layer places several constraints on this # Need to specify num_attention_heads and num_layers or TransformerConfig # will generate errors. - num_layers=len(hybrid_override_pattern), + num_layers=len(layer_type_list), num_attention_heads=4, use_cpu_initialization=True, ) @@ -39,7 +40,8 @@ def get_mamba_block(self, hybrid_override_pattern): return MambaStack( transformer_config, modules, - hybrid_override_pattern=hybrid_override_pattern, + layer_type_list=layer_type_list, + pp_layer_offset=0, pg_collection=self.get_pg_collection(), ) @@ -48,8 +50,8 @@ def teardown_method(self, method): def test_gpu_forward(self): """Test GPU forward pass.""" - hybrid_override_pattern = Symbols.MAMBA + Symbols.ATTENTION + Symbols.MLP - block = self.get_mamba_block(hybrid_override_pattern) + layer_pattern = Symbols.MAMBA + Symbols.ATTENTION + Symbols.MLP + block = self.get_mamba_block(layer_pattern) block.cuda() micro_batch_size = 2 sequence_length = 32 @@ -67,13 +69,13 @@ def test_gpu_forward(self): def test_layer_types(self): """ - Make sure that the layer types specified with hybrid_override_pattern + Make sure that the layer types specified with layer_pattern were honored. """ - hybrid_override_pattern = Symbols.MAMBA + Symbols.ATTENTION + Symbols.MLP - block = self.get_mamba_block(hybrid_override_pattern) + layer_pattern = Symbols.MAMBA + Symbols.ATTENTION + Symbols.MLP + block = self.get_mamba_block(layer_pattern) layers = block.layers - # Note that this matches the order specified by hybrid_override_pattern in setup_method + # Note that this matches the order specified by layer_pattern above assert isinstance(layers[0], MambaLayer) assert isinstance(layers[1], TransformerLayer) assert isinstance(layers[1].self_attention, SelfAttention) @@ -82,8 +84,8 @@ def test_layer_types(self): def test_invalid_layer_types_cause_failure(self): invalid_symbol = '+' - assert invalid_symbol not in Symbols.VALID # sanity check. - hybrid_override_pattern = Symbols.MAMBA + Symbols.ATTENTION + Symbols.MLP + invalid_symbol - # _allocate_override() in mamba_hybrid_layer_allocation.py throws a ValueError. + assert invalid_symbol not in Symbols.VALID_LAYERS # sanity check. + layer_pattern = Symbols.MAMBA + Symbols.ATTENTION + Symbols.MLP + invalid_symbol + # validate_segment_layers() in mamba_hybrid_layer_allocation.py throws a ValueError. with pytest.raises(ValueError): - block = self.get_mamba_block(hybrid_override_pattern) + block = self.get_mamba_block(layer_pattern) diff --git a/tests/unit_tests/ssm/test_mamba_hybrid_layer_allocation.py b/tests/unit_tests/ssm/test_mamba_hybrid_layer_allocation.py index 77c106c3bee..d891e092ef1 100644 --- a/tests/unit_tests/ssm/test_mamba_hybrid_layer_allocation.py +++ b/tests/unit_tests/ssm/test_mamba_hybrid_layer_allocation.py @@ -1,85 +1,135 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -import math -import re +from unittest.mock import patch import pytest -import torch from megatron.core.ssm.mamba_hybrid_layer_allocation import ( ParsedHybridPattern, Symbols, - allocate_layers, + get_hybrid_layer_counts, + get_hybrid_total_layer_count, + get_hybrid_total_pipeline_segment_count, parse_hybrid_pattern, + pattern_from_ratios, + select_pipeline_segment, + validate_segment_layers, ) @pytest.mark.internal -class TestMambaHybridLayerAllocation: +class TestPatternFromRatios: - def test_hybrid_layer_allocation(self): - # The format for the test cases is: - # (layers_count, attention_ratio, mlp_ratio, override_pattern). + def test_pure_mamba(self): + result = pattern_from_ratios(8, attention_ratio=0.0, mlp_ratio=0.0) + assert result == "MMMMMMMM" + + def test_attention_only(self): + result = pattern_from_ratios(10, attention_ratio=0.3) + assert result.count(Symbols.ATTENTION) == 3 + assert result.count(Symbols.MAMBA) == 7 + assert len(result) == 10 + + def test_attention_and_mlp(self): + result = pattern_from_ratios(10, attention_ratio=0.3, mlp_ratio=0.3) + assert result.count(Symbols.ATTENTION) == 3 + assert result.count(Symbols.MLP) == 3 + assert result.count(Symbols.MAMBA) == 4 + assert len(result) == 10 + + def test_attention_evenly_spaced(self): + result = pattern_from_ratios(10, attention_ratio=0.5) + assert result.count(Symbols.ATTENTION) == 5 + assert result.count(Symbols.MAMBA) == 5 + attn_positions = [i for i, ch in enumerate(result) if ch == Symbols.ATTENTION] + gaps = [attn_positions[i + 1] - attn_positions[i] for i in range(len(attn_positions) - 1)] + assert all( + g in (1, 2, 3) for g in gaps + ), f"Gaps between attention layers should be small, got {gaps}" + + def test_mlp_does_not_replace_attention(self): + result = pattern_from_ratios(10, attention_ratio=0.3, mlp_ratio=0.3) + attn_positions = [i for i, c in enumerate(result) if c == Symbols.ATTENTION] + mlp_positions = [i for i, c in enumerate(result) if c == Symbols.MLP] + assert not set(attn_positions) & set(mlp_positions) + + def test_single_layer(self): + assert pattern_from_ratios(1, 0.0, 0.0) == "M" + assert pattern_from_ratios(1, 1.0, 0.0) == "*" + + def test_returns_string(self): + result = pattern_from_ratios(4, 0.5) + assert isinstance(result, str) + + +@pytest.mark.internal +class TestValidateSegmentLayers: + + def test_valid_patterns(self): + """Test that valid segment patterns produce the correct layer type lists.""" test_cases = [ - (9, 0.0, 0.0, "M*-M*-M*-"), - (9, 0.0, 0.0, "MMMMMMMMM"), - (30, 0.0, 0.0, None), - (8, 0.25, 0.25, "MM*-MM*-"), - (8, 0.5, 0.25, "M**-M**-"), - (48, 0.5, 0.2, None), + ("M*-M*-M*-", ['M', '*', '-', 'M', '*', '-', 'M', '*', '-']), + ("MMMMMMMMM", ['M'] * 9), + ("MM*-MM*-", ['M', 'M', '*', '-', 'M', 'M', '*', '-']), + ("E", ['E']), + ("", []), ] - for test in test_cases: - (layers_count, attention_ratio, mlp_ratio, override_pattern) = test + for pattern, expected in test_cases: + result = validate_segment_layers(pattern) + assert result == expected, f"Failed for pattern: {pattern}" - layer_types = allocate_layers(*test) + def test_all_valid_symbols(self): + """Make sure all returned layers are valid.""" + for pattern in ["M*-M*-M*-", "MMMMMMMMM", "MM*-", "MEME"]: + layer_types = validate_segment_layers(pattern) + for layer_type in layer_types: + assert layer_type in Symbols.VALID_LAYERS - # Check that return value is in the right format. - assert isinstance(layer_types, list) - assert layers_count == len(layer_types) + def test_invalid_symbols_cause_failure(self): + """Test that invalid symbols raise ValueError.""" + with pytest.raises(ValueError): + validate_segment_layers("M*X") + with pytest.raises(ValueError): + validate_segment_layers("M|M") # pipe not valid in a segment + with pytest.raises(ValueError): + validate_segment_layers("M/M") # MTP separator not valid in a segment - # Make sure all the layers are valid. - for layer_type in layer_types: - assert layer_type in Symbols.VALID - - # Make sure each layer is as requested by override_pattern. - if override_pattern is not None: - assert len(override_pattern) == len(layer_types) - for index, layer_type in enumerate(layer_types): - assert override_pattern[index] == layer_types[index] - else: - # Make sure the count of each type of layer is correct. - counts = {layer_type: 0 for layer_type in Symbols.VALID} # Initialize all to zero. - for layer_type in layer_types: - assert layer_type in counts - counts[layer_type] += 1 - # Check the ratios. - remainder = 1.0 - attention_ratio - mlp_ratio - assert remainder >= 0 - assert int(attention_ratio * layers_count + 0.5) == counts[Symbols.ATTENTION] - assert int(mlp_ratio * layers_count + 0.5) == counts[Symbols.MLP] - assert int(remainder * layers_count + 0.5) == counts[Symbols.MAMBA] - - # Make sure the ratios are as requested. - # This code is not working yet because capsys seems broken in Megatron. - # captured = capsys.readouterr() # Remove this output from the capture buffer. - # out = captured.out # Get stdout. - # if attention_ratio != 0 or mlp_ratio != 0: - # assert ( - # match := re.search(r'Actual attention ratio: (1\.0|0\.[0-9]+)\.', out) - # ) and math.isclose(match.group(1), attention_ratio) - # assert ( - # match := re.search(r'Actual mlp ratio: (1\.0|0\.[0-9]+)\.', out) - # ) and math.isclose(match.group(1), mlp_ratio) - - @pytest.mark.xfail(raises=ValueError) - def test_wrong_length_override_pattern(self): - # This override_pattern is too short. - layer_types = allocate_layers(9, 0.0, 0.0, "M*-M*-") - - @pytest.mark.xfail(raises=ValueError) - def test_wrong_number_of_layer_types_in_override_pattern(self): - # This override_pattern has too many mlps and not enough attention - layer_types = allocate_layers(8, 0.5, 0.25, "M*--M**-") + +@pytest.mark.internal +class TestGetHybridTotalLayerCount: + + def test_simple_patterns(self): + assert get_hybrid_total_layer_count("M*M*") == 4 + assert get_hybrid_total_layer_count("MMMM") == 4 + assert get_hybrid_total_layer_count("M") == 1 + + def test_with_pipe_separators(self): + assert get_hybrid_total_layer_count("M-M-|M-M*-") == 9 + assert get_hybrid_total_layer_count("M-M-|M-M*-|M-M-|M-M*-") == 18 + assert get_hybrid_total_layer_count("||M") == 1 + assert get_hybrid_total_layer_count("M|M") == 2 + + def test_with_mtp(self): + assert get_hybrid_total_layer_count("M*M*/MM/MM") == 4 + assert get_hybrid_total_layer_count("M-M-|M-M*-/MM/MM") == 9 + + def test_empty(self): + assert get_hybrid_total_layer_count("") == 0 + + +@pytest.mark.internal +class TestGetHybridTotalPipelineSegmentCount: + + def test_no_pipe(self): + assert get_hybrid_total_pipeline_segment_count("M*M*") == 1 + + def test_with_pipes(self): + assert get_hybrid_total_pipeline_segment_count("M-M-|M-M*-") == 2 + assert get_hybrid_total_pipeline_segment_count("M|M|M|M") == 4 + assert get_hybrid_total_pipeline_segment_count("||M") == 3 + + def test_with_mtp(self): + assert get_hybrid_total_pipeline_segment_count("M-M-|M-M*-/MM/MM") == 2 @pytest.mark.internal @@ -108,6 +158,15 @@ def test_main_pattern_only(self): assert result.mtp_pattern is None assert result.mtp_num_depths == 0 + def test_main_pattern_with_pipes(self): + """Test patterns with pipe separators (no MTP).""" + test_cases = [("M*|M*", "M*|M*"), ("M-M-|M-M*-", "M-M-|M-M*-"), ("M|M|M|M", "M|M|M|M")] + for pattern, expected_main in test_cases: + result = parse_hybrid_pattern(pattern) + assert result.main_pattern == expected_main, f"Failed for pattern: {pattern}" + assert result.mtp_pattern is None + assert result.mtp_num_depths == 0 + def test_main_with_single_mtp_depth(self): """Test patterns with 1 MTP depth.""" test_cases = [ @@ -136,6 +195,13 @@ def test_main_with_multiple_mtp_depths(self): assert result.mtp_pattern == expected_mtp, f"Failed for pattern: {pattern}" assert result.mtp_num_depths == expected_depths, f"Failed for pattern: {pattern}" + def test_pipe_with_mtp(self): + """Test patterns with both pipe and MTP separators.""" + result = parse_hybrid_pattern("M-M-|M-M*-/MM/MM") + assert result.main_pattern == "M-M-|M-M*-" + assert result.mtp_pattern == "MM" + assert result.mtp_num_depths == 2 + def test_mtp_patterns_must_be_identical(self): """Test that mismatched MTP patterns raise ValueError.""" invalid_patterns = [ @@ -173,6 +239,11 @@ def test_invalid_symbols_in_mtp_pattern(self): with pytest.raises(ValueError, match="All MTP patterns must be identical"): parse_hybrid_pattern("M*M*/MM/Ma") + def test_pipe_not_allowed_in_mtp(self): + """Test that pipe symbol in MTP pattern raises ValueError.""" + with pytest.raises(ValueError, match="not a valid layer symbol"): + parse_hybrid_pattern("M*M*/M|M/M|M") + def test_empty_main_pattern_with_mtp(self): """Test pattern that starts with / (empty main pattern).""" result = parse_hybrid_pattern("/MM/MM") @@ -212,3 +283,129 @@ def test_dataclass_equality(self): p1 = parse_hybrid_pattern("M*M*/MM/MM") p2 = ParsedHybridPattern(main_pattern="M*M*", mtp_pattern="MM", mtp_num_depths=2) assert p1 == p2 + + +@pytest.mark.internal +class TestGetHybridLayerCounts: + + def test_simple_pattern(self): + assert get_hybrid_layer_counts("M*M*") == {'*': 2, 'M': 2, '-': 0, 'E': 0} + + def test_all_layer_types(self): + assert get_hybrid_layer_counts("M*-E") == {'*': 1, 'M': 1, '-': 1, 'E': 1} + + def test_with_pipes(self): + # Pipes should be skipped in counting + assert get_hybrid_layer_counts("M*|M*") == {'*': 2, 'M': 2, '-': 0, 'E': 0} + assert get_hybrid_layer_counts("M-M-|M-M*-") == {'*': 1, 'M': 4, '-': 4, 'E': 0} + + def test_with_mtp(self): + # MTP pattern "MM" repeated 2 depths -> 4 extra mamba layers + assert get_hybrid_layer_counts("M*M*/MM/MM") == {'*': 2, 'M': 6, '-': 0, 'E': 0} + + def test_with_pipes_and_mtp(self): + # Main: M-M-|M-M*- -> 1 attn, 4 mamba, 4 mlp + # MTP: MM x 2 depths -> +4 mamba + assert get_hybrid_layer_counts("M-M-|M-M*-/MM/MM") == {'*': 1, 'M': 8, '-': 4, 'E': 0} + + def test_moe_pattern(self): + assert get_hybrid_layer_counts("MEME") == {'*': 0, 'M': 2, '-': 0, 'E': 2} + + def test_mtp_with_attention(self): + # MTP pattern "*M" repeated 3 depths -> 3 attn + 3 mamba from MTP + assert get_hybrid_layer_counts("MMMM/*M/*M/*M") == {'*': 3, 'M': 7, '-': 0, 'E': 0} + + def test_empty_pattern(self): + assert get_hybrid_layer_counts("") == {'*': 0, 'M': 0, '-': 0, 'E': 0} + + +@pytest.mark.internal +class TestSelectPipelineSegment: + """Tests for select_pipeline_segment with pp_group=None (single rank). + + When pp_group is None, pp_rank=0 and pp_size=1, so the segment index + is simply the vp_stage value. + """ + + @patch('megatron.core.ssm.mamba_hybrid_layer_allocation.log_on_each_pipeline_stage') + def test_single_segment_no_vp(self, mock_log): + """Single segment, no VPP.""" + layer_types, offset = select_pipeline_segment("M*M*", pp_group=None, vp_stage=None) + assert layer_types == ['M', '*', 'M', '*'] + assert offset == 0 + + @patch('megatron.core.ssm.mamba_hybrid_layer_allocation.log_on_each_pipeline_stage') + def test_two_segments_vp0(self, mock_log): + """Two segments, select first (vp_stage=0).""" + layer_types, offset = select_pipeline_segment("M-M-|M-M*-", pp_group=None, vp_stage=0) + assert layer_types == ['M', '-', 'M', '-'] + assert offset == 0 + + @patch('megatron.core.ssm.mamba_hybrid_layer_allocation.log_on_each_pipeline_stage') + def test_two_segments_vp1(self, mock_log): + """Two segments, select second (vp_stage=1).""" + layer_types, offset = select_pipeline_segment("M-M-|M-M*-", pp_group=None, vp_stage=1) + assert layer_types == ['M', '-', 'M', '*', '-'] + assert offset == 4 + + @patch('megatron.core.ssm.mamba_hybrid_layer_allocation.log_on_each_pipeline_stage') + def test_four_segments(self, mock_log): + """Four segments, verify each vp_stage selects correctly.""" + pattern = "MM|M*|M-|ME" + expected = [(['M', 'M'], 0), (['M', '*'], 2), (['M', '-'], 4), (['M', 'E'], 6)] + for vp_stage, (expected_layers, expected_offset) in enumerate(expected): + layer_types, offset = select_pipeline_segment(pattern, pp_group=None, vp_stage=vp_stage) + assert layer_types == expected_layers, f"Failed for vp_stage={vp_stage}" + assert offset == expected_offset, f"Failed for vp_stage={vp_stage}" + + @patch('megatron.core.ssm.mamba_hybrid_layer_allocation.log_on_each_pipeline_stage') + def test_empty_segment(self, mock_log): + """Empty segments are allowed for pipeline balancing.""" + layer_types, offset = select_pipeline_segment("||M*", pp_group=None, vp_stage=0) + assert layer_types == [] + assert offset == 0 + + layer_types, offset = select_pipeline_segment("||M*", pp_group=None, vp_stage=2) + assert layer_types == ['M', '*'] + assert offset == 0 + + @patch('megatron.core.ssm.mamba_hybrid_layer_allocation.log_on_each_pipeline_stage') + def test_uneven_segments(self, mock_log): + """Segments of different lengths.""" + pattern = "MMM|M|MMMMM" + layer_types, offset = select_pipeline_segment(pattern, pp_group=None, vp_stage=0) + assert len(layer_types) == 3 + assert offset == 0 + + layer_types, offset = select_pipeline_segment(pattern, pp_group=None, vp_stage=1) + assert len(layer_types) == 1 + assert offset == 3 + + layer_types, offset = select_pipeline_segment(pattern, pp_group=None, vp_stage=2) + assert len(layer_types) == 5 + assert offset == 4 + + @patch('megatron.core.ssm.mamba_hybrid_layer_allocation.log_on_each_pipeline_stage') + def test_empty_main_pattern(self, mock_log): + """Empty main pattern produces one empty segment.""" + layer_types, offset = select_pipeline_segment("", pp_group=None, vp_stage=None) + assert layer_types == [] + assert offset == 0 + + @patch('megatron.core.ssm.mamba_hybrid_layer_allocation.log_on_each_pipeline_stage') + def test_invalid_segment_raises(self, mock_log): + """Invalid layer symbols in a segment should raise ValueError.""" + with pytest.raises(ValueError): + select_pipeline_segment("MX|M*", pp_group=None, vp_stage=0) + + @patch('megatron.core.ssm.mamba_hybrid_layer_allocation.log_on_each_pipeline_stage') + def test_out_of_range_segment_raises(self, mock_log): + """Segment index out of range should raise IndexError.""" + with pytest.raises(IndexError): + select_pipeline_segment("M*|M*", pp_group=None, vp_stage=5) + + @patch('megatron.core.ssm.mamba_hybrid_layer_allocation.log_on_each_pipeline_stage') + def test_logging_is_called(self, mock_log): + """Verify that log_on_each_pipeline_stage is called.""" + select_pipeline_segment("M*M*", pp_group=None, vp_stage=None) + mock_log.assert_called_once() diff --git a/tests/unit_tests/transformer/test_cuda_graphs.py b/tests/unit_tests/transformer/test_cuda_graphs.py index 325994cbf89..1580cd7cc41 100644 --- a/tests/unit_tests/transformer/test_cuda_graphs.py +++ b/tests/unit_tests/transformer/test_cuda_graphs.py @@ -23,6 +23,7 @@ from megatron.core.pipeline_parallel.schedules import set_current_microbatch from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.ssm.mamba_block import MambaStack +from megatron.core.ssm.mamba_hybrid_layer_allocation import validate_segment_layers from megatron.core.tensor_parallel.random import ( HAVE_TE, initialize_rng_tracker, @@ -472,12 +473,13 @@ def setup_method(self, method): def get_pg_collection(): return ProcessGroupCollection.use_mpu_process_groups(required_pgs=['tp', 'pp', 'cp']) - def get_mamba_block(hybrid_override_pattern): + def get_mamba_block(hybrid_layer_pattern): + layer_type_list = validate_segment_layers(hybrid_layer_pattern) transformer_config = TransformerConfig( hidden_size=256, # The Mamba layer places several constraints on this # Need to specify num_attention_heads and num_layers or TransformerConfig # will generate errors. - num_layers=len(hybrid_override_pattern), + num_layers=len(layer_type_list), num_attention_heads=4, use_cpu_initialization=True, cuda_graph_impl="local", @@ -486,11 +488,12 @@ def get_mamba_block(hybrid_override_pattern): return MambaStack( transformer_config, modules, - hybrid_override_pattern=hybrid_override_pattern, + layer_type_list=layer_type_list, + pp_layer_offset=0, pg_collection=get_pg_collection(), ) - self.mamba_block = get_mamba_block(hybrid_override_pattern="M-M*-") + self.mamba_block = get_mamba_block(hybrid_layer_pattern="M-M*-") self.transformer_config = self.mamba_block.config def teardown_method(self, method): diff --git a/tests/unit_tests/transformer/test_multi_token_prediction.py b/tests/unit_tests/transformer/test_multi_token_prediction.py index ec72d713eb1..c5fd9688505 100644 --- a/tests/unit_tests/transformer/test_multi_token_prediction.py +++ b/tests/unit_tests/transformer/test_multi_token_prediction.py @@ -705,7 +705,7 @@ def teardown_method(self, method): def model_provider(self, pre_process=True, post_process=True, **config_kwargs): """Model provider for Mamba hybrid models with MTP. - Uses the unified pattern syntax where MTP is configured via hybrid_override_pattern: + Uses the unified pattern syntax where MTP is configured via hybrid_layer_pattern: Format: "///..." Example: "M*M*/M*/M*" = main decoder "M*M*", MTP pattern "M*" with 2 depths """ @@ -713,7 +713,7 @@ def model_provider(self, pre_process=True, post_process=True, **config_kwargs): args = get_args() config = core_transformer_config_from_args(args) - # MTP is configured via unified pattern in hybrid_override_pattern + # MTP is configured via unified pattern in hybrid_layer_pattern # MambaModel creates the MTP block internally based on the parsed pattern model = MambaModel( config=config, @@ -722,9 +722,7 @@ def model_provider(self, pre_process=True, post_process=True, **config_kwargs): max_sequence_length=args.max_position_embeddings, pre_process=pre_process, post_process=post_process, - hybrid_attention_ratio=args.hybrid_attention_ratio, - hybrid_mlp_ratio=args.hybrid_mlp_ratio, - hybrid_override_pattern=args.hybrid_override_pattern, + hybrid_layer_pattern=args.hybrid_layer_pattern, fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, parallel_output=True, share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, @@ -741,7 +739,6 @@ def create_test_args( sys.argv = ['test_multi_token_prediction_mamba.py'] args = parse_args() - args.num_layers = 4 args.mtp_num_layers = 2 args.mtp_loss_scaling_factor = 0.1 args.vocab_size = 128800 @@ -767,10 +764,8 @@ def create_test_args( args.no_load_optim = True args.no_load_rng = True args.bf16 = True - args.hybrid_attention_ratio = 0.5 - args.hybrid_mlp_ratio = 0.0 # Unified pattern: "main/mtp/mtp" - main decoder "M*M*", MTP pattern "M*" with 2 depths - args.hybrid_override_pattern = "M*M*/M*/M*" + args.hybrid_layer_pattern = "M*M*/M*/M*" args.spec = "megatron.core.models.mamba.mamba_layer_specs.mamba_stack_spec" if fp8 is not None: diff --git a/train_rl.py b/train_rl.py index 645e78ba986..50013a9ce1f 100644 --- a/train_rl.py +++ b/train_rl.py @@ -22,6 +22,7 @@ load_packed_data_by_index, ) from megatron.training import get_args, get_timers, pretrain, print_rank_0 +from megatron.training.utils import is_hybrid_model from megatron.training.arguments import core_transformer_config_from_args from model_provider import model_provider @@ -377,7 +378,7 @@ def __getitem__(self, idx): def _model_builder( args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None ): - if getattr(args, "is_hybrid_model", False): + if is_hybrid_model(args): return mamba_builder( args, pre_process,