From 225a529516eab40c7779303283393e7aeef1e44c Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 24 Feb 2026 00:28:28 +0100 Subject: [PATCH 1/7] Port DeepSeek Sparse Attention to `MambaModel` --- .../core/models/mamba/mamba_layer_specs.py | 82 +++++++++++++++++++ megatron/core/ssm/mamba_block.py | 10 +++ .../core/ssm/mamba_hybrid_layer_allocation.py | 8 +- 3 files changed, 98 insertions(+), 2 deletions(-) diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py index 6ca628475be..620bfa3ffa5 100755 --- a/megatron/core/models/mamba/mamba_layer_specs.py +++ b/megatron/core/models/mamba/mamba_layer_specs.py @@ -4,6 +4,7 @@ TEColumnParallelLinear, TEDotProductAttention, TELayerNormColumnParallelLinear, + TELinear, TENorm, TERowParallelLinear, ) @@ -19,7 +20,18 @@ ) from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.experimental_attention_variant.dsa import ( + DSAIndexer, + DSAIndexerSubmodules, + DSAttention, + DSAttentionSubmodules, +) +from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.multi_latent_attention import ( + MLASelfAttention, + MLASelfAttentionSubmodules, +) from megatron.core.transformer.multi_token_prediction import ( MultiTokenPredictionBlock, MultiTokenPredictionBlockSubmodules, @@ -96,6 +108,41 @@ self_attn_bda=get_bias_dropout_add, ), ), + dsa_layer=ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=TENorm, + self_attention=ModuleSpec( + module=MLASelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=MLASelfAttentionSubmodules( + linear_q_proj=TEColumnParallelLinear, + linear_q_down_proj=TELinear, + linear_q_up_proj=TEColumnParallelLinear, + linear_kv_down_proj=TELinear, + linear_kv_up_proj=TEColumnParallelLinear, + core_attention=ModuleSpec( + module=DSAttention, + submodules=DSAttentionSubmodules( + indexer=ModuleSpec( + module=DSAIndexer, + submodules=DSAIndexerSubmodules( + linear_wq_b=TELinear, + linear_wk=TELinear, + k_norm=TENorm, + linear_weights_proj=TELinear, + ), + ) + ), + ), + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + kv_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + ), + ), # Started with spec from gpt_layer_specs.py # Using the TE spec because we had problems getting the non-TE spec # working @@ -156,6 +203,41 @@ self_attn_bda=get_bias_dropout_add, ), ), + dsa_layer=ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=TENorm, + self_attention=ModuleSpec( + module=MLASelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=MLASelfAttentionSubmodules( + linear_q_proj=TEColumnParallelLinear, + linear_q_down_proj=TELinear, + linear_q_up_proj=TEColumnParallelLinear, + linear_kv_down_proj=TELinear, + linear_kv_up_proj=TEColumnParallelLinear, + core_attention=ModuleSpec( + module=DSAttention, + submodules=DSAttentionSubmodules( + indexer=ModuleSpec( + module=DSAIndexer, + submodules=DSAIndexerSubmodules( + linear_wq_b=TELinear, + linear_wk=TELinear, + k_norm=TENorm, + linear_weights_proj=TELinear, + ), + ) + ), + ), + linear_proj=InferenceRowParallelLinear, + q_layernorm=IdentityOp, + kv_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + ), + ), # Started with spec from gpt_layer_specs.py # Using the TE spec because we had problems getting the non-TE spec # working diff --git a/megatron/core/ssm/mamba_block.py b/megatron/core/ssm/mamba_block.py index ef67983d4cf..fbcd4fa7d7e 100644 --- a/megatron/core/ssm/mamba_block.py +++ b/megatron/core/ssm/mamba_block.py @@ -41,6 +41,7 @@ class MambaStackSubmodules: mamba_layer: Union[ModuleSpec, type] = IdentityOp attention_layer: Union[ModuleSpec, type] = IdentityOp + dsa_layer: Union[ModuleSpec, type] = IdentityOp mlp_layer: Union[ModuleSpec, type] = IdentityOp moe_layer: Union[ModuleSpec, type] = IdentityOp mtp_block_spec: Optional[ModuleSpec] = None @@ -160,6 +161,15 @@ def __init__( pg_collection=pg_collection, is_mtp_layer=is_mtp_layer, ) + elif layer_type == LayerSymbols.DSA_ATTENTION: + # DSA attention layers apply their own pp_layer_offset + layer = build_module( + submodules.dsa_layer, + config=self.config, + layer_number=i + 1, + pg_collection=pg_collection, + is_mtp_layer=is_mtp_layer, + ) elif layer_type == LayerSymbols.MLP: # MLP layers apply their own pp_layer_offset layer = build_module( diff --git a/megatron/core/ssm/mamba_hybrid_layer_allocation.py b/megatron/core/ssm/mamba_hybrid_layer_allocation.py index d7002b2915d..be89e4f7ac9 100644 --- a/megatron/core/ssm/mamba_hybrid_layer_allocation.py +++ b/megatron/core/ssm/mamba_hybrid_layer_allocation.py @@ -28,10 +28,11 @@ class Symbols: MAMBA = "M" ATTENTION = "*" + DSA_ATTENTION = "S" MLP = "-" MOE = 'E' MTP_SEPARATOR = "/" - VALID = {MAMBA, ATTENTION, MLP, MOE} + VALID = {MAMBA, ATTENTION, DSA_ATTENTION, MLP, MOE} @dataclass @@ -264,6 +265,7 @@ def allocate_layers( logging.INFO, f"Hybrid allocation ({Symbols.MAMBA} is mamba, " f"{Symbols.ATTENTION} is attention, " + f"{Symbols.DSA_ATTENTION} is dsa attention, " f"{Symbols.MLP} is mlp):", ) maybe_log_single_rank(logger, logging.INFO, allocation_string) @@ -303,7 +305,9 @@ def get_layer_maps_from_layer_type_list( layer_types = [Symbols.ATTENTION, Symbols.MAMBA, Symbols.MLP, Symbols.MOE] layer_maps = {layer_type: {} for layer_type in layer_types} for global_layer_idx, layer_type in enumerate(layer_type_list): - layer_map = layer_maps[layer_type] + # DSA attention layers are treated as attention for KV cache mapping. + effective_type = Symbols.ATTENTION if layer_type == Symbols.DSA_ATTENTION else layer_type + layer_map = layer_maps[effective_type] local_layer_idx = len(layer_map) layer_map[global_layer_idx] = local_layer_idx return [layer_maps[layer_type] for layer_type in layer_types] From 9429a63ef6025e4ad6b27bc1e5a8923c9da0b8b4 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 24 Feb 2026 01:08:40 +0100 Subject: [PATCH 2/7] Add DSA Mamba tests DSA = DeepSeek Sparse Attention --- tests/unit_tests/ssm/test_mamba_block.py | 54 +++++++++++++++++++ .../ssm/test_mamba_hybrid_layer_allocation.py | 50 +++++++++++++++++ 2 files changed, 104 insertions(+) diff --git a/tests/unit_tests/ssm/test_mamba_block.py b/tests/unit_tests/ssm/test_mamba_block.py index 909ee47e836..b42267c4316 100644 --- a/tests/unit_tests/ssm/test_mamba_block.py +++ b/tests/unit_tests/ssm/test_mamba_block.py @@ -12,6 +12,8 @@ from megatron.core.transformer import TransformerConfig from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.multi_latent_attention import MLASelfAttention +from megatron.core.transformer.transformer_config import MLATransformerConfig from megatron.core.transformer.transformer_layer import TransformerLayer from tests.unit_tests.test_utilities import Utils @@ -43,6 +45,34 @@ def get_mamba_block(self, hybrid_override_pattern): pg_collection=self.get_pg_collection(), ) + def get_dsa_mamba_block(self, hybrid_override_pattern): + config = MLATransformerConfig( + num_layers=len(hybrid_override_pattern), + hidden_size=256, + num_attention_heads=16, + use_cpu_initialization=True, + bf16=True, + params_dtype=torch.bfloat16, + q_lora_rank=64, + kv_lora_rank=64, + qk_head_dim=64, + qk_pos_emb_head_dim=32, + v_head_dim=64, + rope_type='rope', + rotary_base=10000, + rotary_percent=1.0, + dsa_indexer_n_heads=8, + dsa_indexer_head_dim=64, + dsa_indexer_topk=32, + ) + modules = mamba_stack_spec.submodules + return MambaStack( + config, + modules, + hybrid_override_pattern=hybrid_override_pattern, + pg_collection=self.get_pg_collection(), + ) + def teardown_method(self, method): Utils.destroy_model_parallel() @@ -87,3 +117,27 @@ def test_invalid_layer_types_cause_failure(self): # _allocate_override() in mamba_hybrid_layer_allocation.py throws a ValueError. with pytest.raises(ValueError): block = self.get_mamba_block(hybrid_override_pattern) + + def test_dsa_layer_types(self): + """S symbol creates a TransformerLayer with MLASelfAttention.""" + pattern = Symbols.MAMBA + Symbols.DSA_ATTENTION + Symbols.MAMBA + block = self.get_dsa_mamba_block(pattern) + layers = block.layers + assert isinstance(layers[0], MambaLayer) + assert isinstance(layers[1], TransformerLayer) + assert isinstance(layers[1].self_attention, MLASelfAttention) + assert isinstance(layers[1].self_attention.core_attention, DSAttention) + assert isinstance(layers[2], MambaLayer) + + def test_mixed_attention_and_dsa_layer_types(self): + """* and S in the same block create different attention types.""" + pattern = Symbols.MAMBA + Symbols.ATTENTION + Symbols.DSA_ATTENTION + Symbols.MAMBA + block = self.get_dsa_mamba_block(pattern) + layers = block.layers + assert isinstance(layers[0], MambaLayer) + assert isinstance(layers[1], TransformerLayer) + assert isinstance(layers[1].self_attention, SelfAttention) + assert isinstance(layers[2], TransformerLayer) + assert isinstance(layers[2].self_attention, MLASelfAttention) + assert isinstance(layers[2].self_attention.core_attention, DSAttention) + assert isinstance(layers[3], MambaLayer) diff --git a/tests/unit_tests/ssm/test_mamba_hybrid_layer_allocation.py b/tests/unit_tests/ssm/test_mamba_hybrid_layer_allocation.py index 77c106c3bee..f4b7658fa95 100644 --- a/tests/unit_tests/ssm/test_mamba_hybrid_layer_allocation.py +++ b/tests/unit_tests/ssm/test_mamba_hybrid_layer_allocation.py @@ -10,6 +10,7 @@ ParsedHybridPattern, Symbols, allocate_layers, + get_layer_maps_from_layer_type_list, parse_hybrid_pattern, ) @@ -27,6 +28,8 @@ def test_hybrid_layer_allocation(self): (8, 0.25, 0.25, "MM*-MM*-"), (8, 0.5, 0.25, "M**-M**-"), (48, 0.5, 0.2, None), + (4, 0.0, 0.0, "MSMS"), + (5, 0.0, 0.0, "MSM*-"), ] for test in test_cases: (layers_count, attention_ratio, mlp_ratio, override_pattern) = test @@ -101,6 +104,8 @@ def test_main_pattern_only(self): ("*M*M", "*M*M"), ("MM-*", "MM-*"), ("E", "E"), + ("MSMS", "MSMS"), + ("SM", "SM"), ] for pattern, expected_main in test_cases: result = parse_hybrid_pattern(pattern) @@ -200,6 +205,8 @@ def test_complex_patterns(self): ("*****/M/M/M/M", "*****", "M", 4), # MoE in main pattern ("MEME/MM/MM", "MEME", "MM", 2), + # DSA in main pattern with MTP + ("MSMS/MS/MS", "MSMS", "MS", 2), ] for pattern, expected_main, expected_mtp, expected_depths in test_cases: result = parse_hybrid_pattern(pattern) @@ -212,3 +219,46 @@ def test_dataclass_equality(self): p1 = parse_hybrid_pattern("M*M*/MM/MM") p2 = ParsedHybridPattern(main_pattern="M*M*", mtp_pattern="MM", mtp_num_depths=2) assert p1 == p2 + + +@pytest.mark.internal +class TestGetLayerMapsFromLayerTypeList: + """Tests for get_layer_maps_from_layer_type_list.""" + + def test_standard_layer_types(self): + """Standard symbols each produce a single-entry map at local index 0.""" + maps = get_layer_maps_from_layer_type_list(["*", "M", "-", "E"]) + assert len(maps) == 4 + attention_map, mamba_map, mlp_map, moe_map = maps + assert attention_map == {0: 0} + assert mamba_map == {1: 0} + assert mlp_map == {2: 0} + assert moe_map == {3: 0} + + def test_dsa_maps_to_attention(self): + """S (DSA) layers are treated as attention for KV cache mapping.""" + maps = get_layer_maps_from_layer_type_list(["S", "M", "S", "M"]) + attention_map, mamba_map, mlp_map, moe_map = maps + # S at global indices 0 and 2 land in the attention map + assert attention_map == {0: 0, 2: 1} + assert mamba_map == {1: 0, 3: 1} + assert mlp_map == {} + assert moe_map == {} + + def test_mixed_attention_and_dsa(self): + """Both * and S contribute to the attention map with consecutive local indices.""" + maps = get_layer_maps_from_layer_type_list(["*", "S", "M", "-"]) + attention_map, mamba_map, mlp_map, moe_map = maps + assert attention_map == {0: 0, 1: 1} + assert mamba_map == {2: 0} + assert mlp_map == {3: 0} + assert moe_map == {} + + def test_all_mamba(self): + """All-mamba pattern leaves attention, mlp, and moe maps empty.""" + maps = get_layer_maps_from_layer_type_list(["M", "M", "M"]) + attention_map, mamba_map, mlp_map, moe_map = maps + assert attention_map == {} + assert mamba_map == {0: 0, 1: 1, 2: 2} + assert mlp_map == {} + assert moe_map == {} From 70d02a7d1d630fcafd55117cba656aebbfe14922 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 24 Feb 2026 01:09:00 +0100 Subject: [PATCH 3/7] Fix DSA dispatch And add corresponding tests. DSA = DeepSeek Sparse Attention --- ...rimental_attention_variant_module_specs.py | 3 + .../transformer/test_attention_variant_dsa.py | 73 +++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/megatron/core/models/gpt/experimental_attention_variant_module_specs.py b/megatron/core/models/gpt/experimental_attention_variant_module_specs.py index a7cc7cc0a55..45fcb17d0e4 100644 --- a/megatron/core/models/gpt/experimental_attention_variant_module_specs.py +++ b/megatron/core/models/gpt/experimental_attention_variant_module_specs.py @@ -123,6 +123,7 @@ def get_dsa_module_spec_for_backend( q_layernorm=IdentityOp, kv_layernorm=IdentityOp, ), + metainfo={"fuse_input_layernorm": False}, ) return attention @@ -138,6 +139,8 @@ def get_experimental_attention_variant_module_spec( if config.experimental_attention_variant == "gated_delta_net": return get_gated_delta_net_module_spec(config=config, backend=backend) + elif config.experimental_attention_variant == "dsa": + return get_dsa_module_spec_for_backend(config=config, backend=backend) else: raise ValueError( f"Invalid experimental attention variant: {config.experimental_attention_variant}" diff --git a/tests/unit_tests/transformer/test_attention_variant_dsa.py b/tests/unit_tests/transformer/test_attention_variant_dsa.py index 96253a4ca10..b77e338e035 100644 --- a/tests/unit_tests/transformer/test_attention_variant_dsa.py +++ b/tests/unit_tests/transformer/test_attention_variant_dsa.py @@ -6,6 +6,10 @@ import torch import megatron.core.parallel_state as parallel_state +from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( + get_dsa_module_spec_for_backend, + get_experimental_attention_variant_module_spec, +) from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed @@ -23,6 +27,7 @@ fused_qk_topk_naive, rotate_activation, ) +from megatron.core.transformer.multi_latent_attention import MLASelfAttention from megatron.core.transformer.transformer_config import MLATransformerConfig from tests.unit_tests.test_utilities import Utils @@ -1586,3 +1591,71 @@ def test_dsa_gradient_sync( ), f"Indexer gradient for {name} differs between TP rank 0 and rank {i} after TP sync" Utils.destroy_model_parallel() + + +@pytest.mark.internal +class TestDSAModuleSpecDispatch: + """Tests for get_dsa_module_spec_for_backend and get_experimental_attention_variant_module_spec.""" + + @pytest.fixture(scope='function', autouse=True) + def setup_method(self): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + yield + Utils.destroy_model_parallel() + + def _make_dsa_config(self, **kwargs): + return MLATransformerConfig( + num_layers=2, + hidden_size=256, + num_attention_heads=16, + use_cpu_initialization=True, + bf16=True, + params_dtype=torch.bfloat16, + q_lora_rank=64, + kv_lora_rank=64, + qk_head_dim=64, + qk_pos_emb_head_dim=32, + v_head_dim=64, + rope_type='rope', + rotary_base=10000, + rotary_percent=1.0, + dsa_indexer_n_heads=8, + dsa_indexer_head_dim=64, + dsa_indexer_topk=32, + **kwargs, + ) + + def test_get_experimental_attention_variant_module_spec_dsa(self): + """get_experimental_attention_variant_module_spec dispatches to DSA for variant='dsa'.""" + config = self._make_dsa_config(experimental_attention_variant="dsa") + spec = get_experimental_attention_variant_module_spec(config) + assert spec.module == MLASelfAttention + assert spec.submodules.core_attention.module == DSAttention + + def test_get_dsa_module_spec_for_backend(self): + """get_dsa_module_spec_for_backend returns the correct full spec structure.""" + from megatron.core.extensions.transformer_engine_spec_provider import TESpecProvider + + config = self._make_dsa_config() + backend = TESpecProvider() + spec = get_dsa_module_spec_for_backend(config, backend=backend) + assert spec.module == MLASelfAttention + assert spec.submodules.core_attention.module == DSAttention + assert spec.submodules.core_attention.submodules.indexer.module == DSAIndexer + assert spec.params["attn_mask_type"] == AttnMaskType.causal + + def test_get_dsa_module_spec_requires_mla(self): + """get_dsa_module_spec_for_backend rejects configs without MLA.""" + from megatron.core.transformer import TransformerConfig as _TransformerConfig + + config = _TransformerConfig(num_layers=2, hidden_size=256, num_attention_heads=4) + with pytest.raises(AssertionError, match="only MLA supports sparse attention"): + get_dsa_module_spec_for_backend(config, backend=None) + + def test_get_dsa_module_spec_rejects_qk_l2_norm(self): + """get_dsa_module_spec_for_backend rejects configs with qk_l2_norm=True.""" + config = self._make_dsa_config(qk_l2_norm=True) + with pytest.raises(AssertionError, match="qk_l2_norm is not supported"): + get_dsa_module_spec_for_backend(config, backend=None) From 43e6625f8193ac93c2bc42c28995e9ceba7b9e78 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 24 Feb 2026 18:55:36 +0100 Subject: [PATCH 4/7] Add DSA GPT/Mamba logprob equivalence tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - New pytest test `test_dsa_gpt_mamba_equivalence.py` builds both a GPTModel (DSA, 4 layers) and a MambaModel (pattern S-S-S-S-, 8 layers) in-memory, remaps weights GPT→Mamba, and asserts logprob equivalence across TP=1/PP=1, TP=2/PP=1, and TP=1/PP=2 distributed configs. - New checkpoint conversion utility `tools/checkpoint/remap_gpt_dsa_to_mamba.py` applies the same layer-key remapping (decoder.layers.{N} → {2N}/{2N+1}, decoder.final_layernorm → decoder.final_norm) to DCP checkpoints. - New functional test cases for CI: hybrid_dsa_mamba_logitsmatch_tp1_pp1 and _tp2_pp1, each with model_config.yaml (MambaModel inference) and placeholder golden values. - New CI recipe `tests/test_utils/recipes/h100/mamba-dsa-static-inference.yaml` wiring the two functional test cases to the h100 pipeline. --- .../golden_values_dev_dgx_h100.json | 1 + .../model_config.yaml | 49 ++ .../golden_values_dev_dgx_h100.json | 1 + .../model_config.yaml | 49 ++ .../h100/mamba-dsa-static-inference.yaml | 67 +++ .../models/test_dsa_gpt_mamba_equivalence.py | 493 ++++++++++++++++++ tools/checkpoint/remap_gpt_dsa_to_mamba.py | 171 ++++++ 7 files changed, 831 insertions(+) create mode 100644 tests/functional_tests/test_cases/hybrid/hybrid_dsa_mamba_logitsmatch_tp1_pp1/golden_values_dev_dgx_h100.json create mode 100644 tests/functional_tests/test_cases/hybrid/hybrid_dsa_mamba_logitsmatch_tp1_pp1/model_config.yaml create mode 100644 tests/functional_tests/test_cases/hybrid/hybrid_dsa_mamba_logitsmatch_tp2_pp1/golden_values_dev_dgx_h100.json create mode 100644 tests/functional_tests/test_cases/hybrid/hybrid_dsa_mamba_logitsmatch_tp2_pp1/model_config.yaml create mode 100644 tests/test_utils/recipes/h100/mamba-dsa-static-inference.yaml create mode 100644 tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py create mode 100644 tools/checkpoint/remap_gpt_dsa_to_mamba.py diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dsa_mamba_logitsmatch_tp1_pp1/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/hybrid/hybrid_dsa_mamba_logitsmatch_tp1_pp1/golden_values_dev_dgx_h100.json new file mode 100644 index 00000000000..701447a7a4c --- /dev/null +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dsa_mamba_logitsmatch_tp1_pp1/golden_values_dev_dgx_h100.json @@ -0,0 +1 @@ +{"0": {"input_prompt": "The quick brown fox jumps over the lazy dog.", "generated_text": "", "generated_tokens": [], "logprobs": []}} diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dsa_mamba_logitsmatch_tp1_pp1/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_dsa_mamba_logitsmatch_tp1_pp1/model_config.yaml new file mode 100644 index 00000000000..187062b565a --- /dev/null +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dsa_mamba_logitsmatch_tp1_pp1/model_config.yaml @@ -0,0 +1,49 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Ring + CUBLAS_WORKSPACE_CONFIG: ":4096:8" +TEST_TYPE: frozen-start +MODE: inference +MODEL_ARGS: + --load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_v3_proxy/dcp/checkpoint + --tokenizer-type: TikTokenizer + --tiktoken-pattern: v2 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 1 + --pipeline-model-parallel-size: 1 + --use-mcore-models: true + --is-hybrid-model: true + --model-provider: mamba + --disable-bias-linear: true + --position-embedding-type: rope + --multi-latent-attention: true + --q-lora-rank: 64 + --kv-lora-rank: 64 + --qk-head-dim: 64 + --qk-pos-emb-head-dim: 32 + --v-head-dim: 64 + --dsa-indexer-n-heads: 8 + --dsa-indexer-head-dim: 64 + --dsa-indexer-topk: 32 + --num-layers: 8 + --hidden-size: 256 + --num-attention-heads: 16 + --hybrid-override-pattern: "S-S-S-S-" + --spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec + --normalization: RMSNorm + --swiglu: true + --bf16: true + --attention-backend: flash + --deterministic-mode: true + --temperature: 1.0 + --top_k: 1 + --return-log-probs: true + --num-tokens-to-generate: 30 + --inference-max-seq-length: 256 + --output-path: ${INFERENCE_OUTPUT_PATH} + --prompts: "The quick brown fox jumps over the lazy dog." + --incoming-requests-per-sec: -1 +METRICS: + - "generated_tokens" + - "logprobs" diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dsa_mamba_logitsmatch_tp2_pp1/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/hybrid/hybrid_dsa_mamba_logitsmatch_tp2_pp1/golden_values_dev_dgx_h100.json new file mode 100644 index 00000000000..701447a7a4c --- /dev/null +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dsa_mamba_logitsmatch_tp2_pp1/golden_values_dev_dgx_h100.json @@ -0,0 +1 @@ +{"0": {"input_prompt": "The quick brown fox jumps over the lazy dog.", "generated_text": "", "generated_tokens": [], "logprobs": []}} diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dsa_mamba_logitsmatch_tp2_pp1/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_dsa_mamba_logitsmatch_tp2_pp1/model_config.yaml new file mode 100644 index 00000000000..86792d13bb6 --- /dev/null +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dsa_mamba_logitsmatch_tp2_pp1/model_config.yaml @@ -0,0 +1,49 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Ring + CUBLAS_WORKSPACE_CONFIG: ":4096:8" +TEST_TYPE: frozen-start +MODE: inference +MODEL_ARGS: + --load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_v3_proxy/dcp/checkpoint + --tokenizer-type: TikTokenizer + --tiktoken-pattern: v2 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 2 + --pipeline-model-parallel-size: 1 + --use-mcore-models: true + --is-hybrid-model: true + --model-provider: mamba + --disable-bias-linear: true + --position-embedding-type: rope + --multi-latent-attention: true + --q-lora-rank: 64 + --kv-lora-rank: 64 + --qk-head-dim: 64 + --qk-pos-emb-head-dim: 32 + --v-head-dim: 64 + --dsa-indexer-n-heads: 8 + --dsa-indexer-head-dim: 64 + --dsa-indexer-topk: 32 + --num-layers: 8 + --hidden-size: 256 + --num-attention-heads: 16 + --hybrid-override-pattern: "S-S-S-S-" + --spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec + --normalization: RMSNorm + --swiglu: true + --bf16: true + --attention-backend: flash + --deterministic-mode: true + --temperature: 1.0 + --top_k: 1 + --return-log-probs: true + --num-tokens-to-generate: 30 + --inference-max-seq-length: 256 + --output-path: ${INFERENCE_OUTPUT_PATH} + --prompts: "The quick brown fox jumps over the lazy dog." + --incoming-requests-per-sec: -1 +METRICS: + - "generated_tokens" + - "logprobs" diff --git a/tests/test_utils/recipes/h100/mamba-dsa-static-inference.yaml b/tests/test_utils/recipes/h100/mamba-dsa-static-inference.yaml new file mode 100644 index 00000000000..64965755ff2 --- /dev/null +++ b/tests/test_utils/recipes/h100/mamba-dsa-static-inference.yaml @@ -0,0 +1,67 @@ +type: basic +format_version: 1 +maintainers: [mcore] +loggers: [stdout] +spec: + name: "{test_case}_{environment}_{platforms}" + model: hybrid + build: mcore-pyt-{environment} + nodes: 1 + gpus: 2 + n_repeat: 1 + platforms: dgx_h100 + script_setup: | + unset https_proxy + echo "machine gitlab-master.nvidia.com login okoenig password $RO_API_TOKEN" | tee -a /root/.netrc + + # Checkout latest + cd /opt + rm -rf /opt/megatron-lm; mkdir megatron-lm; cd megatron-lm + git init + git remote add origin $MCORE_REPO + git fetch origin '+refs/merge-requests/*:refs/remotes/merge-requests/*' + git fetch origin $MCORE_MR_COMMIT + git checkout $MCORE_MR_COMMIT + git rev-parse HEAD + # Checkout backwards-ref + cd /opt + rm -rf /opt/megatron-lm-legacy; mkdir megatron-lm-legacy; cd megatron-lm-legacy + git init + git remote add origin $MCORE_REPO + git fetch origin $MCORE_BACKWARDS_COMMIT + git checkout $MCORE_BACKWARDS_COMMIT + git rev-parse HEAD + rm -rf megatron; cp -a /opt/megatron-lm/megatron ./ + script: |- + ls + cd /opt/megatron-lm + + ARGUMENTS=( + "CHECKPOINT_LOAD_PATH=/mnt/artifacts" + "CHECKPOINT_SAVE_PATH=/tmp/checkpoints" + "DATA_PATH=null" + "DATA_CACHE_PATH=/workspace/data/cache" + "TRAINING_SCRIPT_PATH=examples/inference/gpt/gpt_static_inference.py" + "TRAINING_PARAMS_PATH=./tests/functional_tests/test_cases/{model}/{test_case}/model_config.yaml" + "GOLDEN_VALUES_PATH=./tests/functional_tests/test_cases/{model}/{test_case}/golden_values_{environment}_{platforms}.json" + "OUTPUT_PATH={assets_dir}" + "TENSORBOARD_PATH={assets_dir}/tensorboard" + "INFERENCE_OUTPUT_PATH={assets_dir}/golden_values_{environment}_{platforms}.json" + "N_REPEAT={n_repeat}" + "ENABLE_LIGHTWEIGHT_MODE=${{ENABLE_LIGHTWEIGHT_MODE}}" + "RECORD_CHECKPOINTS=${{RECORD_CHECKPOINTS}}" + ) + + bash ./tests/functional_tests/shell_test_utils/run_ci_test.sh ${{ARGUMENTS[@]}} + +products: + - test_case: [hybrid_dsa_mamba_logitsmatch_tp1_pp1] + products: + - environment: [dev] + scope: [mr-broken, mr-github-broken] + platforms: [dgx_h100] + - test_case: [hybrid_dsa_mamba_logitsmatch_tp2_pp1] + products: + - environment: [dev] + scope: [mr-broken, mr-github-broken] + platforms: [dgx_h100] diff --git a/tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py b/tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py new file mode 100644 index 00000000000..27018f67f39 --- /dev/null +++ b/tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py @@ -0,0 +1,493 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +""" +Equivalence tests: GPTModel with DSA vs MambaModel with DSA pattern. + +A small DeepSeek-V3.2 proxy model (4 GPT layers / 8 Mamba layers) is built, +weights are remapped GPT→Mamba, and logprobs are compared to verify they are +numerically identical. + +Architecture equivalence +------------------------ +GPTModel layer N (combined attention + MLP in one TransformerLayer) + ≡ MambaModel layer 2N (S, DSA TransformerLayer: input_layernorm + MLASelfAttention) + + MambaModel layer 2N+1 (-, MLPLayer: fused-norm MLP) + +Run with:: + + torchrun --nproc-per-node=2 -m pytest \\ + tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py -v +""" + +import copy +import json +import math +import os +from pathlib import Path +from typing import Dict, Optional +from unittest.mock import patch + +import pytest +import torch +import torch.distributed as dist + +import megatron.core.parallel_state as mpu +from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( + get_transformer_block_with_experimental_attention_variant_spec, +) +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec +from megatron.core.models.mamba.mamba_model import MambaModel +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.transformer_config import MLATransformerConfig +from megatron.rl.rl_utils import selective_log_softmax +from tests.unit_tests.test_utilities import Utils + +try: + from fast_hadamard_transform import hadamard_transform as _hadamard_transform + + HAVE_HADAMARD = True +except ImportError: + HAVE_HADAMARD = False + + +# --------------------------------------------------------------------------- +# Hadamard mock (used when the library is not installed) +# --------------------------------------------------------------------------- + + +def _mock_hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + """Identity-scale mock for hadamard_transform used in DSA.""" + return x * scale + + +@pytest.fixture(autouse=True) +def _patch_hadamard_if_needed(): + """Patch hadamard_transform in the DSA module when the library is absent.""" + if not HAVE_HADAMARD: + with patch( + 'megatron.core.transformer.experimental_attention_variant.dsa.hadamard_transform', + _mock_hadamard_transform, + ): + yield + else: + yield + + +# --------------------------------------------------------------------------- +# Proxy model constants +# --------------------------------------------------------------------------- + +_VOCAB_SIZE = 256 +_MAX_SEQ_LEN = 64 +_SEQ_LEN = 32 +_BATCH_SIZE = 2 +_NUM_GPT_LAYERS = 4 +_MAMBA_PATTERN = "S-S-S-S-" # len=8 = 2 * _NUM_GPT_LAYERS + + +# --------------------------------------------------------------------------- +# Model construction helpers +# --------------------------------------------------------------------------- + + +def _make_dsa_config(num_layers: int, tp: int = 1) -> MLATransformerConfig: + """Return a small DeepSeek-V3.2 proxy MLATransformerConfig.""" + return MLATransformerConfig( + num_layers=num_layers, + hidden_size=256, + num_attention_heads=16, + q_lora_rank=64, + kv_lora_rank=64, + qk_head_dim=64, + qk_pos_emb_head_dim=32, + v_head_dim=64, + dsa_indexer_n_heads=8, + dsa_indexer_head_dim=64, + dsa_indexer_topk=32, + normalization="RMSNorm", + bf16=True, + params_dtype=torch.bfloat16, + add_bias_linear=False, + use_cpu_initialization=True, + rope_type='rope', + rotary_base=10000, + rotary_percent=1.0, + experimental_attention_variant="dsa", + hidden_dropout=0.0, + attention_dropout=0.0, + tensor_model_parallel_size=tp, + ) + + +def _make_pg_collection() -> ProcessGroupCollection: + return ProcessGroupCollection.use_mpu_process_groups(required_pgs=['tp', 'pp', 'cp']) + + +def _build_gpt_model( + config: MLATransformerConfig, + pre_process: bool = True, + post_process: bool = True, +) -> GPTModel: + """Build a GPTModel with the DSA transformer block spec.""" + spec = get_transformer_block_with_experimental_attention_variant_spec(config) + model = GPTModel( + config=config, + transformer_layer_spec=spec, + vocab_size=_VOCAB_SIZE, + max_sequence_length=_MAX_SEQ_LEN, + pre_process=pre_process, + post_process=post_process, + parallel_output=False, # Gather logits across TP for easy comparison + position_embedding_type='rope', + rotary_base=10000, + rotary_percent=1.0, + pg_collection=_make_pg_collection(), + ) + return model.cuda() + + +def _build_mamba_model( + config: MLATransformerConfig, + pattern: str, + pre_process: bool = True, + post_process: bool = True, +) -> MambaModel: + """Build a MambaModel with the given hybrid pattern.""" + mamba_config = copy.deepcopy(config) + mamba_config.num_layers = len(pattern) + model = MambaModel( + config=mamba_config, + mamba_stack_spec=mamba_stack_spec, + vocab_size=_VOCAB_SIZE, + max_sequence_length=_MAX_SEQ_LEN, + pre_process=pre_process, + post_process=post_process, + parallel_output=False, + hybrid_override_pattern=pattern, + position_embedding_type='rope', + rotary_base=10000, + rotary_percent=1.0, + pg_collection=_make_pg_collection(), + ) + return model.cuda() + + +# --------------------------------------------------------------------------- +# Weight remapping +# --------------------------------------------------------------------------- + + +def _remap_gpt_to_mamba_state_dict( + gpt_sd: Dict[str, torch.Tensor], + num_local_gpt_layers: int, +) -> Dict[str, torch.Tensor]: + """Remap a GPTModel state_dict to a MambaModel state_dict. + + GPTModel layer N (combined attention + MLP) maps to: + * MambaModel layer 2N – DSA attention (input_layernorm + self_attention) + * MambaModel layer 2N+1 – MLP (mlp.*) + + Additionally, ``decoder.final_layernorm.*`` (TransformerBlock naming) is + remapped to ``decoder.final_norm.*`` (MambaStack naming). + + All other keys (embedding, output_layer, rotary_pos_emb, …) are unchanged. + + Args: + gpt_sd: State dict obtained from GPTModel.state_dict(). + num_local_gpt_layers: Number of GPT decoder layers on the current + pipeline stage (i.e. ``len(gpt_model.decoder.layers)``). + + Returns: + Remapped state dict ready for MambaModel.load_state_dict(strict=True). + """ + mamba_sd: Dict[str, torch.Tensor] = {} + layer_prefix = "decoder.layers." + final_ln_prefix = "decoder.final_layernorm." + + for key, value in gpt_sd.items(): + # ---- final layernorm rename ---- + if key.startswith(final_ln_prefix): + suffix = key[len(final_ln_prefix):] + mamba_sd[f"decoder.final_norm.{suffix}"] = value + continue + + # ---- non-layer keys pass through unchanged ---- + if not key.startswith(layer_prefix): + mamba_sd[key] = value + continue + + # ---- parse "decoder.layers.{N}.{rest}" ---- + remainder = key[len(layer_prefix):] + dot_idx = remainder.index('.') + layer_n = int(remainder[:dot_idx]) + rest = remainder[dot_idx + 1:] # e.g. "self_attention.linear_q_proj.weight" + + assert 0 <= layer_n < num_local_gpt_layers, ( + f"Layer index {layer_n} out of range [0, {num_local_gpt_layers}) in key '{key}'" + ) + + if rest.startswith("input_layernorm.") or rest.startswith("self_attention."): + # Attention sub-module → DSA layer 2N + mamba_sd[f"{layer_prefix}{2 * layer_n}.{rest}"] = value + elif rest.startswith("mlp."): + # MLP sub-module → MLP layer 2N+1 + mamba_sd[f"{layer_prefix}{2 * layer_n + 1}.{rest}"] = value + else: + # pre_mlp_layernorm is IdentityOp (no weights); self_attn_bda / mlp_bda + # are callables (no weights). Anything else is unexpected. + raise ValueError( + f"Unexpected sub-key '{rest}' in GPT layer {layer_n} (full key='{key}'). " + "Expected: input_layernorm.*, self_attention.*, mlp.*" + ) + + return mamba_sd + + +# --------------------------------------------------------------------------- +# Forward-pass helpers +# --------------------------------------------------------------------------- + + +def _make_inputs(tokens: torch.Tensor): + """Return position_ids and attention_mask for a token batch.""" + batch_size, seq_len = tokens.shape + position_ids = ( + torch.arange(seq_len, device=tokens.device).unsqueeze(0).expand(batch_size, seq_len) + ) + attention_mask = torch.ones( + batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=tokens.device + ) + return position_ids, attention_mask + + +def _forward_logprobs_pp1(model: torch.nn.Module, tokens: torch.Tensor) -> torch.Tensor: + """Single-stage (PP=1) forward returning logprobs [batch, seq-1].""" + position_ids, attention_mask = _make_inputs(tokens) + with torch.no_grad(): + logits = model( + input_ids=tokens, position_ids=position_ids, attention_mask=attention_mask + ) + return selective_log_softmax(logits[:, :-1, :], tokens[:, 1:]) + + +def _forward_logprobs_pp2( + model: torch.nn.Module, tokens: torch.Tensor +) -> Optional[torch.Tensor]: + """Two-stage (PP=2) forward using point-to-point communication. + + Returns logprobs on the last PP stage; None on the first stage. + The caller must invoke this function for both GPT and Mamba models in the + *same order* on all ranks to avoid deadlocks. + """ + batch_size, seq_len = tokens.shape + hidden_size = model.config.hidden_size + position_ids, attention_mask = _make_inputs(tokens) + + pp_rank = mpu.get_pipeline_model_parallel_rank() + next_rank = mpu.get_pipeline_model_parallel_next_rank() + prev_rank = mpu.get_pipeline_model_parallel_prev_rank() + + if pp_rank == 0: + # First stage: embedding + local layers → hidden states [seq, batch, hidden] + with torch.no_grad(): + hidden = model( + input_ids=tokens, position_ids=position_ids, attention_mask=attention_mask + ) + dist.send(hidden.contiguous(), dst=next_rank) + return None + else: + # Last stage: receive hidden states, run remaining layers → logits + hidden_buf = torch.empty( + seq_len, batch_size, hidden_size, + dtype=torch.bfloat16, device=tokens.device, + ) + dist.recv(hidden_buf, src=prev_rank) + model.set_input_tensor(hidden_buf) + with torch.no_grad(): + logits = model( + input_ids=tokens, position_ids=position_ids, attention_mask=attention_mask + ) + return selective_log_softmax(logits[:, :-1, :], tokens[:, 1:]) + + +# --------------------------------------------------------------------------- +# Golden-value recording / comparison helpers +# --------------------------------------------------------------------------- + + +def _save_golden_values(logprobs: torch.Tensor, path: Path) -> None: + """Save GPTModel logprobs to JSON in the functional test format. + + Format:: + + {"0": {"logprobs": [...], "generated_tokens": []}} + + Args: + logprobs: Tensor of shape [batch, seq-1] on CUDA. + path: JSON output path (parent directory must exist). + """ + path.parent.mkdir(parents=True, exist_ok=True) + lp_list = logprobs[0].float().tolist() # first batch item + data = {"0": {"logprobs": lp_list, "generated_tokens": []}} + with open(path, "w") as fh: + json.dump(data, fh, indent=2) + + +def _compare_against_golden_values(logprobs: torch.Tensor, path: Path, abs_tol: float = 1e-3): + """Assert logprobs match the golden values JSON within *abs_tol*.""" + with open(path) as fh: + golden = json.load(fh) + golden_lp = golden["0"]["logprobs"] + actual_lp = logprobs[0].float().tolist() + assert len(actual_lp) == len(golden_lp), ( + f"Logprob length mismatch: actual={len(actual_lp)}, golden={len(golden_lp)}" + ) + for i, (a, g) in enumerate(zip(actual_lp, golden_lp)): + assert math.isclose(a, g, abs_tol=abs_tol), ( + f"Logprob mismatch at position {i}: actual={a:.6f}, golden={g:.6f}" + ) + + +# --------------------------------------------------------------------------- +# Test class +# --------------------------------------------------------------------------- + +_GOLDEN_BASE = Path(__file__).parent.parent.parent / ( + "functional_tests/test_cases/hybrid" +) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("tp,pp", [(1, 1), (2, 1), (1, 2)]) +class TestDSAGPTMambaEquivalence: + """Verify logprob equivalence between GPTModel+DSA and MambaModel+DSA. + + For each distributed configuration (TP, PP), the test: + 1. Builds a GPTModel with 4 DSA layers. + 2. Builds a MambaModel with pattern "S-S-S-S-" (8 layers). + 3. Remaps and loads GPT weights into MambaModel (strict=True). + 4. Runs the same random tokens through both models. + 5. Asserts logprob tensors are numerically close. + """ + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def _skip_if_insufficient_gpus(self, tp: int, pp: int) -> None: + world_size = int(os.environ.get('WORLD_SIZE', '1')) + required = tp * pp + if world_size < required: + pytest.skip( + f"Test tp={tp} pp={pp} requires {required} GPU(s), " + f"but WORLD_SIZE={world_size}" + ) + + def test_dsa_logprobs_match(self, tp: int, pp: int) -> None: + """Build both models, transfer weights, compare logprobs.""" + self._skip_if_insufficient_gpus(tp, pp) + Utils.initialize_model_parallel(tp, pp) + model_parallel_cuda_manual_seed(42) + + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + + # ---- Build GPTModel ---- + gpt_config = _make_dsa_config(num_layers=_NUM_GPT_LAYERS, tp=tp) + gpt_model = _build_gpt_model(gpt_config, pre_process=pre_process, post_process=post_process) + num_local_gpt_layers = len(gpt_model.decoder.layers) + gpt_sd = gpt_model.state_dict() + + # ---- Build MambaModel ---- + mamba_model = _build_mamba_model( + gpt_config, _MAMBA_PATTERN, pre_process=pre_process, post_process=post_process + ) + + # ---- Remap GPT weights → Mamba ---- + mamba_sd = _remap_gpt_to_mamba_state_dict(gpt_sd, num_local_gpt_layers) + missing, unexpected = mamba_model.load_state_dict(mamba_sd, strict=True) + assert not missing, f"Missing keys after weight remap: {missing}" + assert not unexpected, f"Unexpected keys after weight remap: {unexpected}" + + # ---- Create identical inputs on all ranks ---- + torch.manual_seed(99) + tokens = torch.randint(0, _VOCAB_SIZE, (_BATCH_SIZE, _SEQ_LEN), device='cuda') + + # ---- Forward pass ---- + if pp == 1: + gpt_logprobs = _forward_logprobs_pp1(gpt_model, tokens) + mamba_logprobs = _forward_logprobs_pp1(mamba_model, tokens) + # Both models have full logits; compare on all TP ranks + torch.testing.assert_close( + gpt_logprobs, mamba_logprobs, atol=1e-5, rtol=1e-5, + msg=f"Logprob mismatch for tp={tp} pp={pp}", + ) + else: + # PP=2: manual pipeline communication + # Run GPT then Mamba in the same order on all ranks to avoid deadlocks. + gpt_logprobs = _forward_logprobs_pp2(gpt_model, tokens) + mamba_logprobs = _forward_logprobs_pp2(mamba_model, tokens) + if post_process: + torch.testing.assert_close( + gpt_logprobs, mamba_logprobs, atol=1e-5, rtol=1e-5, + msg=f"Logprob mismatch for tp={tp} pp={pp}", + ) + + def test_weight_loading_strict(self, tp: int, pp: int) -> None: + """Verify that strict=True weight loading succeeds (no missing/unexpected keys).""" + self._skip_if_insufficient_gpus(tp, pp) + Utils.initialize_model_parallel(tp, pp) + model_parallel_cuda_manual_seed(42) + + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + + gpt_config = _make_dsa_config(num_layers=_NUM_GPT_LAYERS, tp=tp) + gpt_model = _build_gpt_model(gpt_config, pre_process=pre_process, post_process=post_process) + mamba_model = _build_mamba_model( + gpt_config, _MAMBA_PATTERN, pre_process=pre_process, post_process=post_process + ) + + gpt_sd = gpt_model.state_dict() + num_local_gpt_layers = len(gpt_model.decoder.layers) + mamba_sd = _remap_gpt_to_mamba_state_dict(gpt_sd, num_local_gpt_layers) + missing, unexpected = mamba_model.load_state_dict(mamba_sd, strict=True) + + assert not missing, f"Missing keys: {missing}" + assert not unexpected, f"Unexpected keys: {unexpected}" + + def test_record_and_compare_golden_values(self, tp: int, pp: int) -> None: + """Record GPTModel logprobs as golden values, then compare MambaModel against them. + + Golden values are written to the functional test directory so they can be + committed and used by the CI inference golden-value tests (Part 2 of the plan). + """ + self._skip_if_insufficient_gpus(tp, pp) + # Only run for TP=1, PP=1 (the canonical golden-value configuration) + if tp != 1 or pp != 1: + pytest.skip("Golden-value recording only runs for tp=1, pp=1") + + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(42) + + gpt_config = _make_dsa_config(num_layers=_NUM_GPT_LAYERS, tp=1) + gpt_model = _build_gpt_model(gpt_config) + mamba_model = _build_mamba_model(gpt_config, _MAMBA_PATTERN) + + gpt_sd = gpt_model.state_dict() + mamba_sd = _remap_gpt_to_mamba_state_dict(gpt_sd, len(gpt_model.decoder.layers)) + mamba_model.load_state_dict(mamba_sd, strict=True) + + torch.manual_seed(99) + tokens = torch.randint(0, _VOCAB_SIZE, (_BATCH_SIZE, _SEQ_LEN), device='cuda') + + gpt_logprobs = _forward_logprobs_pp1(gpt_model, tokens) + mamba_logprobs = _forward_logprobs_pp1(mamba_model, tokens) + + # Save golden values recorded from GPTModel + golden_dir = _GOLDEN_BASE / "hybrid_dsa_mamba_logitsmatch_tp1_pp1" + golden_path = golden_dir / "golden_values_dev_dgx_h100.json" + _save_golden_values(gpt_logprobs, golden_path) + + # Verify MambaModel matches golden values + _compare_against_golden_values(mamba_logprobs, golden_path, abs_tol=1e-3) diff --git a/tools/checkpoint/remap_gpt_dsa_to_mamba.py b/tools/checkpoint/remap_gpt_dsa_to_mamba.py new file mode 100644 index 00000000000..2e963636cf7 --- /dev/null +++ b/tools/checkpoint/remap_gpt_dsa_to_mamba.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +"""Convert a GPTModel DSA checkpoint to a MambaModel-compatible checkpoint. + +A GPTModel with ``--experimental-attention-variant dsa`` uses one combined +TransformerLayer per model layer (attention + MLP). The equivalent MambaModel +with pattern ``S-S-...`` stores them as two separate layers: + +* Layer 2N – DSA attention (TransformerLayer: input_layernorm + MLASelfAttention) +* Layer 2N+1 – MLP (MLPLayer: fused-norm MLP) + +This script loads a GPTModel Distributed Checkpoint (DCP), remaps the state-dict +keys, and saves a new DCP that can be loaded by MambaModel. + +Usage +----- +:: + + python tools/checkpoint/remap_gpt_dsa_to_mamba.py \\ + --input /path/to/gpt_dsa_dcp_checkpoint \\ + --output /path/to/mamba_dsa_dcp_checkpoint \\ + --num-gpt-layers 4 + +Key remapping rules +------------------- +* ``decoder.layers.{N}.input_layernorm.*`` → ``decoder.layers.{2N}.input_layernorm.*`` +* ``decoder.layers.{N}.self_attention.*`` → ``decoder.layers.{2N}.self_attention.*`` +* ``decoder.layers.{N}.mlp.*`` → ``decoder.layers.{2N+1}.mlp.*`` +* ``decoder.final_layernorm.*`` → ``decoder.final_norm.*`` +* All other keys → unchanged +""" + +import argparse +import os +import re +import shutil +from pathlib import Path +from typing import Dict + + +def _remap_key(key: str, num_gpt_layers: int) -> str: + """Return the MambaModel state-dict key corresponding to *key* from GPTModel. + + Args: + key: A key from the GPTModel state dict. + num_gpt_layers: Total number of GPT decoder layers (across all PP stages). + + Returns: + The remapped key for MambaModel. + + Raises: + ValueError: If an unexpected sub-key is encountered in a decoder layer. + """ + layer_prefix = "decoder.layers." + final_ln_prefix = "decoder.final_layernorm." + + # Final layernorm name differs between TransformerBlock and MambaStack + if key.startswith(final_ln_prefix): + return "decoder.final_norm." + key[len(final_ln_prefix):] + + if not key.startswith(layer_prefix): + return key # embedding, output_layer, rotary_pos_emb, etc. + + # Parse "decoder.layers.{N}.{rest}" + remainder = key[len(layer_prefix):] + dot_idx = remainder.index('.') + layer_n = int(remainder[:dot_idx]) + rest = remainder[dot_idx + 1:] + + if rest.startswith("input_layernorm.") or rest.startswith("self_attention."): + return f"{layer_prefix}{2 * layer_n}.{rest}" + elif rest.startswith("mlp."): + return f"{layer_prefix}{2 * layer_n + 1}.{rest}" + else: + raise ValueError( + f"Unexpected sub-key '{rest}' in GPT layer {layer_n} (full key='{key}'). " + "Expected: input_layernorm.*, self_attention.*, mlp.*" + ) + + +def _remap_state_dict( + gpt_sd: Dict, num_gpt_layers: int +) -> Dict: + """Apply key remapping to the full GPTModel state dict.""" + return {_remap_key(k, num_gpt_layers): v for k, v in gpt_sd.items()} + + +def convert(input_path: Path, output_path: Path, num_gpt_layers: int) -> None: + """Load a GPTModel DCP checkpoint, remap keys, and save as MambaModel DCP. + + Args: + input_path: Path to the GPTModel DCP checkpoint directory. + output_path: Destination directory for the MambaModel DCP checkpoint. + num_gpt_layers: Number of GPT decoder layers in the original model. + """ + try: + import torch + import torch.distributed.checkpoint as dcp + from torch.distributed.checkpoint.format_utils import ( + dcp_to_torch_save, + torch_save_to_dcp, + ) + except ImportError as exc: + raise SystemExit( + "PyTorch distributed checkpoint (torch.distributed.checkpoint) is required. " + "Please upgrade to PyTorch >= 2.0." + ) from exc + + import torch + + print(f"Loading GPTModel checkpoint from: {input_path}") + + # --- Load the flat state dict from DCP --- + # We use dcp_to_torch_save to materialize the DCP into a regular .pt file, + # then remap keys, then convert back to DCP. + tmp_flat = output_path.parent / "_tmp_gpt_flat.pt" + try: + dcp_to_torch_save(str(input_path), str(tmp_flat)) + gpt_sd = torch.load(tmp_flat, map_location="cpu") + print(f"Loaded {len(gpt_sd)} keys from GPTModel checkpoint.") + + # --- Remap keys --- + mamba_sd = _remap_state_dict(gpt_sd, num_gpt_layers) + print( + f"Remapped state dict: {len(gpt_sd)} GPT keys → {len(mamba_sd)} Mamba keys." + ) + + # --- Save remapped state dict as a new flat .pt then convert to DCP --- + tmp_mamba = output_path.parent / "_tmp_mamba_flat.pt" + torch.save(mamba_sd, tmp_mamba) + + output_path.mkdir(parents=True, exist_ok=True) + torch_save_to_dcp(str(tmp_mamba), str(output_path)) + print(f"MambaModel DCP checkpoint saved to: {output_path}") + + finally: + for tmp in (tmp_flat, output_path.parent / "_tmp_mamba_flat.pt"): + if tmp.exists(): + tmp.unlink() + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert GPTModel DSA checkpoint to MambaModel-compatible format." + ) + parser.add_argument( + "--input", required=True, type=Path, + help="Path to the source GPTModel DCP checkpoint directory.", + ) + parser.add_argument( + "--output", required=True, type=Path, + help="Destination path for the MambaModel DCP checkpoint.", + ) + parser.add_argument( + "--num-gpt-layers", required=True, type=int, + help="Number of decoder layers in the GPTModel (e.g. 4).", + ) + args = parser.parse_args() + + if not args.input.exists(): + raise SystemExit(f"Input checkpoint not found: {args.input}") + if args.output.exists(): + print(f"Warning: output path already exists and will be overwritten: {args.output}") + shutil.rmtree(args.output) + + convert(args.input, args.output, args.num_gpt_layers) + print("Conversion complete.") + + +if __name__ == "__main__": + main() From 1e116b0fe977ef53e76eb08145745838d644eeaa Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 24 Feb 2026 19:23:23 +0100 Subject: [PATCH 5/7] Add DSA+MoE GPT/Mamba equivalence tests Extend the DSA GPT/Mamba logprob equivalence suite to cover mixed dense+MoE architectures, mirroring the real DeepSeek-V3 layout where the first N layers are dense and the remaining layers use MoE. Key changes: - Add `pre_mlp_layernorm.*` routing in `_remap_gpt_to_mamba_state_dict` and `_remap_key` (checkpoint tool): MoE layers expose a real TENorm for `pre_mlp_layernorm` (not fused), which maps to MoETransformerLayer 2N+1. Dense layers use IdentityOp and produce no keys, so existing tests are unaffected. - Add `_make_dsa_moe_config` with `moe_layer_freq=[0,0,1,1]` (first 2 GPT layers dense, last 2 MoE) and proxy MoE params matching the DeepSeek-V3 style (4 experts, grouped-gemm, allgather dispatcher, shared experts). - Add `_MOE_MAMBA_PATTERN = "S-S-SESE"` and `TestDSAMoEGPTMambaEquivalence` with the same three parametrized tests as the dense suite (tp=1/2 pp=1/2): logprob match, strict weight loading, and golden-value recording/comparison. - Add functional test configs (`hybrid_dsa_moe_mamba_logitsmatch_tp{1,2}_pp1`) with placeholder golden-value JSONs and corresponding CI recipe entries in `mamba-dsa-static-inference.yaml`. --- .../golden_values_dev_dgx_h100.json | 1 + .../model_config.yaml | 59 ++++++ .../golden_values_dev_dgx_h100.json | 1 + .../model_config.yaml | 59 ++++++ .../h100/mamba-dsa-static-inference.yaml | 10 + .../models/test_dsa_gpt_mamba_equivalence.py | 189 +++++++++++++++++- tools/checkpoint/remap_gpt_dsa_to_mamba.py | 20 +- 7 files changed, 330 insertions(+), 9 deletions(-) create mode 100644 tests/functional_tests/test_cases/hybrid/hybrid_dsa_moe_mamba_logitsmatch_tp1_pp1/golden_values_dev_dgx_h100.json create mode 100644 tests/functional_tests/test_cases/hybrid/hybrid_dsa_moe_mamba_logitsmatch_tp1_pp1/model_config.yaml create mode 100644 tests/functional_tests/test_cases/hybrid/hybrid_dsa_moe_mamba_logitsmatch_tp2_pp1/golden_values_dev_dgx_h100.json create mode 100644 tests/functional_tests/test_cases/hybrid/hybrid_dsa_moe_mamba_logitsmatch_tp2_pp1/model_config.yaml diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dsa_moe_mamba_logitsmatch_tp1_pp1/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/hybrid/hybrid_dsa_moe_mamba_logitsmatch_tp1_pp1/golden_values_dev_dgx_h100.json new file mode 100644 index 00000000000..701447a7a4c --- /dev/null +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dsa_moe_mamba_logitsmatch_tp1_pp1/golden_values_dev_dgx_h100.json @@ -0,0 +1 @@ +{"0": {"input_prompt": "The quick brown fox jumps over the lazy dog.", "generated_text": "", "generated_tokens": [], "logprobs": []}} diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dsa_moe_mamba_logitsmatch_tp1_pp1/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_dsa_moe_mamba_logitsmatch_tp1_pp1/model_config.yaml new file mode 100644 index 00000000000..84c031cb3ec --- /dev/null +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dsa_moe_mamba_logitsmatch_tp1_pp1/model_config.yaml @@ -0,0 +1,59 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Ring + CUBLAS_WORKSPACE_CONFIG: ":4096:8" +TEST_TYPE: frozen-start +MODE: inference +MODEL_ARGS: + --load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_v3_proxy/dcp/checkpoint + --tokenizer-type: TikTokenizer + --tiktoken-pattern: v2 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 1 + --pipeline-model-parallel-size: 1 + --use-mcore-models: true + --is-hybrid-model: true + --model-provider: mamba + --disable-bias-linear: true + --position-embedding-type: rope + --multi-latent-attention: true + --q-lora-rank: 64 + --kv-lora-rank: 64 + --qk-head-dim: 64 + --qk-pos-emb-head-dim: 32 + --v-head-dim: 64 + --dsa-indexer-n-heads: 8 + --dsa-indexer-head-dim: 64 + --dsa-indexer-topk: 32 + --num-layers: 8 + --hidden-size: 256 + --num-attention-heads: 16 + --hybrid-override-pattern: "S-S-SESE" + --spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec + --normalization: RMSNorm + --swiglu: true + --bf16: true + --attention-backend: flash + --deterministic-mode: true + --temperature: 1.0 + --top_k: 1 + --return-log-probs: true + --num-tokens-to-generate: 30 + --inference-max-seq-length: 256 + --output-path: ${INFERENCE_OUTPUT_PATH} + --prompts: "The quick brown fox jumps over the lazy dog." + --incoming-requests-per-sec: -1 + # MoE args + --num-experts: 4 + --moe-layer-freq: ([0]*2+[1]*2) + --moe-ffn-hidden-size: 512 + --moe-shared-expert-intermediate-size: 512 + --moe-router-topk: 2 + --moe-grouped-gemm: true + --moe-token-dispatcher-type: allgather + --moe-router-load-balancing-type: aux_loss + --moe-aux-loss-coeff: 0.0 +METRICS: + - "generated_tokens" + - "logprobs" diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dsa_moe_mamba_logitsmatch_tp2_pp1/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/hybrid/hybrid_dsa_moe_mamba_logitsmatch_tp2_pp1/golden_values_dev_dgx_h100.json new file mode 100644 index 00000000000..701447a7a4c --- /dev/null +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dsa_moe_mamba_logitsmatch_tp2_pp1/golden_values_dev_dgx_h100.json @@ -0,0 +1 @@ +{"0": {"input_prompt": "The quick brown fox jumps over the lazy dog.", "generated_text": "", "generated_tokens": [], "logprobs": []}} diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dsa_moe_mamba_logitsmatch_tp2_pp1/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_dsa_moe_mamba_logitsmatch_tp2_pp1/model_config.yaml new file mode 100644 index 00000000000..66f585ca36e --- /dev/null +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dsa_moe_mamba_logitsmatch_tp2_pp1/model_config.yaml @@ -0,0 +1,59 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Ring + CUBLAS_WORKSPACE_CONFIG: ":4096:8" +TEST_TYPE: frozen-start +MODE: inference +MODEL_ARGS: + --load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_v3_proxy/dcp/checkpoint + --tokenizer-type: TikTokenizer + --tiktoken-pattern: v2 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 2 + --pipeline-model-parallel-size: 1 + --use-mcore-models: true + --is-hybrid-model: true + --model-provider: mamba + --disable-bias-linear: true + --position-embedding-type: rope + --multi-latent-attention: true + --q-lora-rank: 64 + --kv-lora-rank: 64 + --qk-head-dim: 64 + --qk-pos-emb-head-dim: 32 + --v-head-dim: 64 + --dsa-indexer-n-heads: 8 + --dsa-indexer-head-dim: 64 + --dsa-indexer-topk: 32 + --num-layers: 8 + --hidden-size: 256 + --num-attention-heads: 16 + --hybrid-override-pattern: "S-S-SESE" + --spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec + --normalization: RMSNorm + --swiglu: true + --bf16: true + --attention-backend: flash + --deterministic-mode: true + --temperature: 1.0 + --top_k: 1 + --return-log-probs: true + --num-tokens-to-generate: 30 + --inference-max-seq-length: 256 + --output-path: ${INFERENCE_OUTPUT_PATH} + --prompts: "The quick brown fox jumps over the lazy dog." + --incoming-requests-per-sec: -1 + # MoE args + --num-experts: 4 + --moe-layer-freq: ([0]*2+[1]*2) + --moe-ffn-hidden-size: 512 + --moe-shared-expert-intermediate-size: 512 + --moe-router-topk: 2 + --moe-grouped-gemm: true + --moe-token-dispatcher-type: allgather + --moe-router-load-balancing-type: aux_loss + --moe-aux-loss-coeff: 0.0 +METRICS: + - "generated_tokens" + - "logprobs" diff --git a/tests/test_utils/recipes/h100/mamba-dsa-static-inference.yaml b/tests/test_utils/recipes/h100/mamba-dsa-static-inference.yaml index 64965755ff2..b2cbe19a6c6 100644 --- a/tests/test_utils/recipes/h100/mamba-dsa-static-inference.yaml +++ b/tests/test_utils/recipes/h100/mamba-dsa-static-inference.yaml @@ -65,3 +65,13 @@ products: - environment: [dev] scope: [mr-broken, mr-github-broken] platforms: [dgx_h100] + - test_case: [hybrid_dsa_moe_mamba_logitsmatch_tp1_pp1] + products: + - environment: [dev] + scope: [mr-broken, mr-github-broken] + platforms: [dgx_h100] + - test_case: [hybrid_dsa_moe_mamba_logitsmatch_tp2_pp1] + products: + - environment: [dev] + scope: [mr-broken, mr-github-broken] + platforms: [dgx_h100] diff --git a/tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py b/tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py index 27018f67f39..6f4b486a962 100644 --- a/tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py +++ b/tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py @@ -85,6 +85,9 @@ def _patch_hadamard_if_needed(): _NUM_GPT_LAYERS = 4 _MAMBA_PATTERN = "S-S-S-S-" # len=8 = 2 * _NUM_GPT_LAYERS +# MoE variant: first 2 GPT layers dense, last 2 MoE → 8 Mamba layers +_MOE_MAMBA_PATTERN = "S-S-SESE" # 2 dense (S-) + 2 MoE (SE) + # --------------------------------------------------------------------------- # Model construction helpers @@ -120,6 +123,49 @@ def _make_dsa_config(num_layers: int, tp: int = 1) -> MLATransformerConfig: ) +def _make_dsa_moe_config(num_layers: int, tp: int = 1) -> MLATransformerConfig: + """Return a small DeepSeek-V3 proxy MLATransformerConfig with MoE layers. + + Mirrors the DeepSeek-V3 pattern: first 2 GPT layers are dense, last 2 are MoE. + ``moe_layer_freq=[0, 0, 1, 1]`` controls which GPT layers become MoE layers. + """ + return MLATransformerConfig( + num_layers=num_layers, + hidden_size=256, + num_attention_heads=16, + q_lora_rank=64, + kv_lora_rank=64, + qk_head_dim=64, + qk_pos_emb_head_dim=32, + v_head_dim=64, + dsa_indexer_n_heads=8, + dsa_indexer_head_dim=64, + dsa_indexer_topk=32, + normalization="RMSNorm", + bf16=True, + params_dtype=torch.bfloat16, + add_bias_linear=False, + use_cpu_initialization=True, + rope_type='rope', + rotary_base=10000, + rotary_percent=1.0, + experimental_attention_variant="dsa", + hidden_dropout=0.0, + attention_dropout=0.0, + tensor_model_parallel_size=tp, + # MoE fields + num_moe_experts=4, + moe_router_topk=2, + moe_grouped_gemm=True, + moe_token_dispatcher_type="allgather", + moe_router_load_balancing_type="aux_loss", + moe_aux_loss_coeff=0.0, + moe_ffn_hidden_size=512, + moe_shared_expert_intermediate_size=512, + moe_layer_freq=[0, 0, 1, 1], # first 2 layers dense, last 2 MoE + ) + + def _make_pg_collection() -> ProcessGroupCollection: return ProcessGroupCollection.use_mpu_process_groups(required_pgs=['tp', 'pp', 'cp']) @@ -233,12 +279,17 @@ def _remap_gpt_to_mamba_state_dict( elif rest.startswith("mlp."): # MLP sub-module → MLP layer 2N+1 mamba_sd[f"{layer_prefix}{2 * layer_n + 1}.{rest}"] = value + elif rest.startswith("pre_mlp_layernorm."): + # pre_mlp_layernorm → MoE layer 2N+1 (MoETransformerLayer has TENorm) + # Dense layers use IdentityOp for pre_mlp_layernorm (no state dict keys), + # so this branch only fires for MoE layers. + mamba_sd[f"{layer_prefix}{2 * layer_n + 1}.{rest}"] = value else: - # pre_mlp_layernorm is IdentityOp (no weights); self_attn_bda / mlp_bda - # are callables (no weights). Anything else is unexpected. + # self_attn_bda / mlp_bda are callables (no weights). + # Anything else is unexpected. raise ValueError( f"Unexpected sub-key '{rest}' in GPT layer {layer_n} (full key='{key}'). " - "Expected: input_layernorm.*, self_attention.*, mlp.*" + "Expected: input_layernorm.*, self_attention.*, pre_mlp_layernorm.*, mlp.*" ) return mamba_sd @@ -491,3 +542,135 @@ def test_record_and_compare_golden_values(self, tp: int, pp: int) -> None: # Verify MambaModel matches golden values _compare_against_golden_values(mamba_logprobs, golden_path, abs_tol=1e-3) + + +# --------------------------------------------------------------------------- +# MoE test class +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("tp,pp", [(1, 1), (2, 1), (1, 2)]) +class TestDSAMoEGPTMambaEquivalence: + """Verify logprob equivalence between GPTModel+DSA+MoE and MambaModel+DSA+MoE. + + Architecture: 4 GPT layers with moe_layer_freq=[0,0,1,1] (first 2 dense, last 2 MoE) + maps to 8 Mamba layers with pattern "S-S-SESE": + GPT layer 0 (dense) → Mamba layers 0 (S) + 1 (-) + GPT layer 1 (dense) → Mamba layers 2 (S) + 3 (-) + GPT layer 2 (MoE) → Mamba layers 4 (S) + 5 (E) + GPT layer 3 (MoE) → Mamba layers 6 (S) + 7 (E) + """ + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def _skip_if_insufficient_gpus(self, tp: int, pp: int) -> None: + world_size = int(os.environ.get('WORLD_SIZE', '1')) + required = tp * pp + if world_size < required: + pytest.skip( + f"Test tp={tp} pp={pp} requires {required} GPU(s), " + f"but WORLD_SIZE={world_size}" + ) + + def test_dsa_moe_logprobs_match(self, tp: int, pp: int) -> None: + """Build both models with MoE, transfer weights, compare logprobs.""" + self._skip_if_insufficient_gpus(tp, pp) + Utils.initialize_model_parallel(tp, pp) + model_parallel_cuda_manual_seed(42) + + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + + # ---- Build GPTModel with MoE ---- + gpt_config = _make_dsa_moe_config(num_layers=_NUM_GPT_LAYERS, tp=tp) + gpt_model = _build_gpt_model(gpt_config, pre_process=pre_process, post_process=post_process) + num_local_gpt_layers = len(gpt_model.decoder.layers) + gpt_sd = gpt_model.state_dict() + + # ---- Build MambaModel with MoE pattern ---- + mamba_model = _build_mamba_model( + gpt_config, _MOE_MAMBA_PATTERN, pre_process=pre_process, post_process=post_process + ) + + # ---- Remap GPT weights → Mamba (includes pre_mlp_layernorm.* for MoE layers) ---- + mamba_sd = _remap_gpt_to_mamba_state_dict(gpt_sd, num_local_gpt_layers) + missing, unexpected = mamba_model.load_state_dict(mamba_sd, strict=True) + assert not missing, f"Missing keys after weight remap: {missing}" + assert not unexpected, f"Unexpected keys after weight remap: {unexpected}" + + # ---- Create identical inputs on all ranks ---- + torch.manual_seed(99) + tokens = torch.randint(0, _VOCAB_SIZE, (_BATCH_SIZE, _SEQ_LEN), device='cuda') + + # ---- Forward pass ---- + if pp == 1: + gpt_logprobs = _forward_logprobs_pp1(gpt_model, tokens) + mamba_logprobs = _forward_logprobs_pp1(mamba_model, tokens) + torch.testing.assert_close( + gpt_logprobs, mamba_logprobs, atol=1e-5, rtol=1e-5, + msg=f"MoE logprob mismatch for tp={tp} pp={pp}", + ) + else: + gpt_logprobs = _forward_logprobs_pp2(gpt_model, tokens) + mamba_logprobs = _forward_logprobs_pp2(mamba_model, tokens) + if post_process: + torch.testing.assert_close( + gpt_logprobs, mamba_logprobs, atol=1e-5, rtol=1e-5, + msg=f"MoE logprob mismatch for tp={tp} pp={pp}", + ) + + def test_moe_weight_loading_strict(self, tp: int, pp: int) -> None: + """Verify that strict=True weight loading succeeds with MoE keys.""" + self._skip_if_insufficient_gpus(tp, pp) + Utils.initialize_model_parallel(tp, pp) + model_parallel_cuda_manual_seed(42) + + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + + gpt_config = _make_dsa_moe_config(num_layers=_NUM_GPT_LAYERS, tp=tp) + gpt_model = _build_gpt_model(gpt_config, pre_process=pre_process, post_process=post_process) + mamba_model = _build_mamba_model( + gpt_config, _MOE_MAMBA_PATTERN, pre_process=pre_process, post_process=post_process + ) + + gpt_sd = gpt_model.state_dict() + num_local_gpt_layers = len(gpt_model.decoder.layers) + mamba_sd = _remap_gpt_to_mamba_state_dict(gpt_sd, num_local_gpt_layers) + missing, unexpected = mamba_model.load_state_dict(mamba_sd, strict=True) + + assert not missing, f"Missing keys: {missing}" + assert not unexpected, f"Unexpected keys: {unexpected}" + + def test_moe_record_and_compare_golden_values(self, tp: int, pp: int) -> None: + """Record GPTModel+MoE logprobs as golden values, then compare MambaModel+MoE.""" + self._skip_if_insufficient_gpus(tp, pp) + if tp != 1 or pp != 1: + pytest.skip("Golden-value recording only runs for tp=1, pp=1") + + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(42) + + gpt_config = _make_dsa_moe_config(num_layers=_NUM_GPT_LAYERS, tp=1) + gpt_model = _build_gpt_model(gpt_config) + mamba_model = _build_mamba_model(gpt_config, _MOE_MAMBA_PATTERN) + + gpt_sd = gpt_model.state_dict() + mamba_sd = _remap_gpt_to_mamba_state_dict(gpt_sd, len(gpt_model.decoder.layers)) + mamba_model.load_state_dict(mamba_sd, strict=True) + + torch.manual_seed(99) + tokens = torch.randint(0, _VOCAB_SIZE, (_BATCH_SIZE, _SEQ_LEN), device='cuda') + + gpt_logprobs = _forward_logprobs_pp1(gpt_model, tokens) + mamba_logprobs = _forward_logprobs_pp1(mamba_model, tokens) + + # Save golden values recorded from GPTModel + golden_dir = _GOLDEN_BASE / "hybrid_dsa_moe_mamba_logitsmatch_tp1_pp1" + golden_path = golden_dir / "golden_values_dev_dgx_h100.json" + _save_golden_values(gpt_logprobs, golden_path) + + # Verify MambaModel matches golden values + _compare_against_golden_values(mamba_logprobs, golden_path, abs_tol=1e-3) diff --git a/tools/checkpoint/remap_gpt_dsa_to_mamba.py b/tools/checkpoint/remap_gpt_dsa_to_mamba.py index 2e963636cf7..001e0fb4fa1 100644 --- a/tools/checkpoint/remap_gpt_dsa_to_mamba.py +++ b/tools/checkpoint/remap_gpt_dsa_to_mamba.py @@ -23,11 +23,15 @@ Key remapping rules ------------------- -* ``decoder.layers.{N}.input_layernorm.*`` → ``decoder.layers.{2N}.input_layernorm.*`` -* ``decoder.layers.{N}.self_attention.*`` → ``decoder.layers.{2N}.self_attention.*`` -* ``decoder.layers.{N}.mlp.*`` → ``decoder.layers.{2N+1}.mlp.*`` -* ``decoder.final_layernorm.*`` → ``decoder.final_norm.*`` -* All other keys → unchanged +* ``decoder.layers.{N}.input_layernorm.*`` → ``decoder.layers.{2N}.input_layernorm.*`` +* ``decoder.layers.{N}.self_attention.*`` → ``decoder.layers.{2N}.self_attention.*`` +* ``decoder.layers.{N}.pre_mlp_layernorm.*`` → ``decoder.layers.{2N+1}.pre_mlp_layernorm.*`` +* ``decoder.layers.{N}.mlp.*`` → ``decoder.layers.{2N+1}.mlp.*`` +* ``decoder.final_layernorm.*`` → ``decoder.final_norm.*`` +* All other keys → unchanged + +Note: ``pre_mlp_layernorm`` only appears for MoE layers (where it is a real TENorm). +Dense layers use ``IdentityOp`` for ``pre_mlp_layernorm``, which produces no state dict keys. """ import argparse @@ -71,10 +75,14 @@ def _remap_key(key: str, num_gpt_layers: int) -> str: return f"{layer_prefix}{2 * layer_n}.{rest}" elif rest.startswith("mlp."): return f"{layer_prefix}{2 * layer_n + 1}.{rest}" + elif rest.startswith("pre_mlp_layernorm."): + # MoE layers have a real TENorm for pre_mlp_layernorm (not fused); + # it maps to MoETransformerLayer 2N+1. + return f"{layer_prefix}{2 * layer_n + 1}.{rest}" else: raise ValueError( f"Unexpected sub-key '{rest}' in GPT layer {layer_n} (full key='{key}'). " - "Expected: input_layernorm.*, self_attention.*, mlp.*" + "Expected: input_layernorm.*, self_attention.*, pre_mlp_layernorm.*, mlp.*" ) From 9b599227faf7c3dc4194996d4a3eeefab693f369 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 24 Feb 2026 19:33:29 +0100 Subject: [PATCH 6/7] Do not hardcode FP8 alignment size --- megatron/core/ssm/mamba_mixer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 25c8984fced..f93dbdd7b60 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -34,6 +34,7 @@ make_sharded_tensors_for_checkpoint, sharded_state_dict_default, ) +from megatron.core.fp8_utils import get_fp8_align_size from megatron.core.utils import ( deprecate_inference_params, is_causal_conv1d_min_version, @@ -207,9 +208,10 @@ def __init__( self.nheads = self.d_inner // self.headdim if self.config.fp8: - assert (2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads) % 16 == 0, ( + fp8_align_size = get_fp8_align_size(self.config.fp8_recipe) + assert (2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads) % fp8_align_size == 0, ( "For FP8, the innermost dimension of the Mamba layer " - "input projection output tensor must be a multiple of 16." + f"input projection output tensor must be a multiple of {fp8_align_size}." ) tp_size = self.pg_collection.tp.size() From c85dc0558dd79f557dddfc90461d4188fbebfc88 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 24 Feb 2026 19:35:19 +0100 Subject: [PATCH 7/7] Add blockwise FP8 Mamba tests --- tests/unit_tests/models/test_mamba_model.py | 79 ++++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/models/test_mamba_model.py b/tests/unit_tests/models/test_mamba_model.py index 29e3630d7bb..37768434fcd 100644 --- a/tests/unit_tests/models/test_mamba_model.py +++ b/tests/unit_tests/models/test_mamba_model.py @@ -22,7 +22,8 @@ from megatron.core.transformer import TransformerConfig from megatron.core.transformer.enums import AttnBackend from megatron.core.transformer.module import Float16Module -from megatron.core.utils import divide, is_fa_min_version, is_torch_min_version +from megatron.core.utils import divide, is_fa_min_version, is_te_min_version, is_torch_min_version +from megatron.training.utils import get_device_arch_version from tests.unit_tests.test_utilities import Utils @@ -394,3 +395,79 @@ def test_dynamic_inference_padding_with_fp8(self): # Assert that all padding logits are zero. assert torch.all(padding_logits == 0.0), "Logits for padding tokens are not all zero." + + +class TestMambaBlockwiseFP8: + """Tests MambaModel with blockwise FP8.""" + + @torch.inference_mode() + def setup_method(self, method): + fp8_available, reason_for_no_fp8 = check_fp8_support() + if not fp8_available: + pytest.skip(reason_for_no_fp8) + + if not is_te_min_version("2.3.0.dev0"): + pytest.skip("blockwise FP8 requires TransformerEngine >= 2.3.0.dev0") + + if get_device_arch_version() < 9: + pytest.skip("blockwise FP8 requires Hopper architecture (compute capability >= 9.0)") + + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + + model_config = TransformerConfig( + num_layers=2, + hidden_size=512, + num_attention_heads=4, + use_cpu_initialization=True, + params_dtype=torch.bfloat16, + bf16=True, + fp8="hybrid", + fp8_recipe="blockwise", + ) + + self.model = MambaModel( + config=model_config, + mamba_stack_spec=mamba_stack_spec, + vocab_size=128, + max_sequence_length=16, + hybrid_attention_ratio=0.5, + hybrid_mlp_ratio=0.0, + ) + self.model = Float16Module(self.model.config, self.model) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + @torch.inference_mode() + def test_blockwise_fp8_forward(self): + """ + Tests that MambaModel can construct and run a forward pass with blockwise FP8. + """ + self.model.cuda() + self.model.eval() + + sequence_length = 16 + micro_batch_size = 2 + + # Prepare inputs + data = list(range(sequence_length)) + input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + attention_mask = torch.ones( + (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool + ).cuda() + + # Run forward pass + logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + runtime_gather_output=True, + ) + + # Verify output shape + assert logits.shape[0] == micro_batch_size + assert logits.shape[1] == sequence_length + assert logits.shape[2] == self.model.module.vocab_size