Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def get_dsa_module_spec_for_backend(
q_layernorm=IdentityOp,
kv_layernorm=IdentityOp,
),
metainfo={"fuse_input_layernorm": False},
)

return attention
Expand All @@ -138,6 +139,8 @@ def get_experimental_attention_variant_module_spec(

if config.experimental_attention_variant == "gated_delta_net":
return get_gated_delta_net_module_spec(config=config, backend=backend)
elif config.experimental_attention_variant == "dsa":
return get_dsa_module_spec_for_backend(config=config, backend=backend)
else:
raise ValueError(
f"Invalid experimental attention variant: {config.experimental_attention_variant}"
Expand Down
82 changes: 82 additions & 0 deletions megatron/core/models/mamba/mamba_layer_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TELinear,
TENorm,
TERowParallelLinear,
)
Expand All @@ -19,7 +20,18 @@
)
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.experimental_attention_variant.dsa import (
DSAIndexer,
DSAIndexerSubmodules,
DSAttention,
DSAttentionSubmodules,
)
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.multi_latent_attention import (
MLASelfAttention,
MLASelfAttentionSubmodules,
)
from megatron.core.transformer.multi_token_prediction import (
MultiTokenPredictionBlock,
MultiTokenPredictionBlockSubmodules,
Expand Down Expand Up @@ -96,6 +108,41 @@
self_attn_bda=get_bias_dropout_add,
),
),
# Hybrid-stack spec entry for a DSA (DeepSeek sparse attention) layer:
# a TransformerLayer whose self-attention is MLASelfAttention with a
# DSAttention core carrying a DSAIndexer submodule.
dsa_layer=ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
# Standalone input layernorm (not fused into the first linear, unlike the
# TELayerNormColumnParallelLinear pattern used by other specs in this file).
input_layernorm=TENorm,
self_attention=ModuleSpec(
module=MLASelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=MLASelfAttentionSubmodules(
# MLA low-rank projections; the down-projections use plain (unsharded)
# TELinear while the up-projections are column-parallel.
linear_q_proj=TEColumnParallelLinear,
linear_q_down_proj=TELinear,
linear_q_up_proj=TEColumnParallelLinear,
linear_kv_down_proj=TELinear,
linear_kv_up_proj=TEColumnParallelLinear,
core_attention=ModuleSpec(
module=DSAttention,
submodules=DSAttentionSubmodules(
# Indexer submodule — presumably produces the per-token scores used
# for DSA's top-k sparse selection; all projections are unsharded
# TELinear. NOTE(review): confirm no TP sharding is intended here.
indexer=ModuleSpec(
module=DSAIndexer,
submodules=DSAIndexerSubmodules(
linear_wq_b=TELinear,
linear_wk=TELinear,
k_norm=TENorm,
linear_weights_proj=TELinear,
),
)
),
),
linear_proj=TERowParallelLinear,
# No extra q/kv layernorms on top of what MLA does internally.
q_layernorm=IdentityOp,
kv_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
),
),
# Started with spec from gpt_layer_specs.py
# Using the TE spec because we had problems getting the non-TE spec
# working
Expand Down Expand Up @@ -156,6 +203,41 @@
self_attn_bda=get_bias_dropout_add,
),
),
# DSA layer spec for the second stack spec in this file — identical to the
# training-spec entry except linear_proj uses InferenceRowParallelLinear,
# so this variant is presumably the inference-optimized spec (TODO confirm
# against the enclosing spec's name, which is outside this diff hunk).
dsa_layer=ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
# Standalone input layernorm; not fused into the first linear.
input_layernorm=TENorm,
self_attention=ModuleSpec(
module=MLASelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=MLASelfAttentionSubmodules(
# MLA low-rank projections; down-projections are unsharded TELinear.
linear_q_proj=TEColumnParallelLinear,
linear_q_down_proj=TELinear,
linear_q_up_proj=TEColumnParallelLinear,
linear_kv_down_proj=TELinear,
linear_kv_up_proj=TEColumnParallelLinear,
core_attention=ModuleSpec(
module=DSAttention,
submodules=DSAttentionSubmodules(
# Indexer submodule for DSA's sparse token selection (see the
# training-spec entry); all indexer projections are unsharded.
indexer=ModuleSpec(
module=DSAIndexer,
submodules=DSAIndexerSubmodules(
linear_wq_b=TELinear,
linear_wk=TELinear,
k_norm=TENorm,
linear_weights_proj=TELinear,
),
)
),
),
# Only difference from the other dsa_layer spec: inference-specialized
# row-parallel output projection.
linear_proj=InferenceRowParallelLinear,
q_layernorm=IdentityOp,
kv_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
),
),
# Started with spec from gpt_layer_specs.py
# Using the TE spec because we had problems getting the non-TE spec
# working
Expand Down
10 changes: 10 additions & 0 deletions megatron/core/ssm/mamba_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class MambaStackSubmodules:

mamba_layer: Union[ModuleSpec, type] = IdentityOp
attention_layer: Union[ModuleSpec, type] = IdentityOp
dsa_layer: Union[ModuleSpec, type] = IdentityOp
mlp_layer: Union[ModuleSpec, type] = IdentityOp
moe_layer: Union[ModuleSpec, type] = IdentityOp
mtp_block_spec: Optional[ModuleSpec] = None
Expand Down Expand Up @@ -160,6 +161,15 @@ def __init__(
pg_collection=pg_collection,
is_mtp_layer=is_mtp_layer,
)
elif layer_type == LayerSymbols.DSA_ATTENTION:
# DSA attention layers apply their own pp_layer_offset
layer = build_module(
submodules.dsa_layer,
config=self.config,
layer_number=i + 1,
pg_collection=pg_collection,
is_mtp_layer=is_mtp_layer,
)
elif layer_type == LayerSymbols.MLP:
# MLP layers apply their own pp_layer_offset
layer = build_module(
Expand Down
8 changes: 6 additions & 2 deletions megatron/core/ssm/mamba_hybrid_layer_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ class Symbols:

MAMBA = "M"
ATTENTION = "*"
DSA_ATTENTION = "S"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering if 'S' should be reserved for sliding-window attention. Wondering if this should be 'D'. Of course, these choices are arbitrary and hopefully ultimately temporary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense!

MLP = "-"
MOE = 'E'
MTP_SEPARATOR = "/"
VALID = {MAMBA, ATTENTION, MLP, MOE}
VALID = {MAMBA, ATTENTION, DSA_ATTENTION, MLP, MOE}


@dataclass
Expand Down Expand Up @@ -264,6 +265,7 @@ def allocate_layers(
logging.INFO,
f"Hybrid allocation ({Symbols.MAMBA} is mamba, "
f"{Symbols.ATTENTION} is attention, "
f"{Symbols.DSA_ATTENTION} is dsa attention, "
f"{Symbols.MLP} is mlp):",
)
maybe_log_single_rank(logger, logging.INFO, allocation_string)
Expand Down Expand Up @@ -303,7 +305,9 @@ def get_layer_maps_from_layer_type_list(
layer_types = [Symbols.ATTENTION, Symbols.MAMBA, Symbols.MLP, Symbols.MOE]
layer_maps = {layer_type: {} for layer_type in layer_types}
for global_layer_idx, layer_type in enumerate(layer_type_list):
layer_map = layer_maps[layer_type]
# DSA attention layers are treated as attention for KV cache mapping.
effective_type = Symbols.ATTENTION if layer_type == Symbols.DSA_ATTENTION else layer_type
layer_map = layer_maps[effective_type]
local_layer_idx = len(layer_map)
layer_map[global_layer_idx] = local_layer_idx
return [layer_maps[layer_type] for layer_type in layer_types]
Expand Down
6 changes: 4 additions & 2 deletions megatron/core/ssm/mamba_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
make_sharded_tensors_for_checkpoint,
sharded_state_dict_default,
)
from megatron.core.fp8_utils import get_fp8_align_size
from megatron.core.utils import (
deprecate_inference_params,
is_causal_conv1d_min_version,
Expand Down Expand Up @@ -207,9 +208,10 @@ def __init__(
self.nheads = self.d_inner // self.headdim

if self.config.fp8:
assert (2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads) % 16 == 0, (
fp8_align_size = get_fp8_align_size(self.config.fp8_recipe)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What prompts this fix in this PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, just wanted to push it for safety's sake. I'll move the DeepSeek-related changes to another branch. :)
Didn't mean anyone to look at this yet since it's still a draft.

assert (2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads) % fp8_align_size == 0, (
"For FP8, the innermost dimension of the Mamba layer "
"input projection output tensor must be a multiple of 16."
f"input projection output tensor must be a multiple of {fp8_align_size}."
)

tp_size = self.pg_collection.tp.size()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"0": {"input_prompt": "The quick brown fox jumps over the lazy dog.", "generated_text": "", "generated_tokens": [], "logprobs": []}}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Functional-test spec: frozen-start inference of a DeepSeek-V3 proxy hybrid
# model whose attention layers are DSA ("S" in --hybrid-override-pattern),
# single GPU (TP=1 / PP=1). Environment is pinned for determinism.
ENV_VARS:
CUDA_DEVICE_MAX_CONNECTIONS: 1
NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0
NCCL_ALGO: Ring
CUBLAS_WORKSPACE_CONFIG: ":4096:8"
TEST_TYPE: frozen-start
MODE: inference
MODEL_ARGS:
# Checkpoint and tokenizer
--load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_v3_proxy/dcp/checkpoint
--tokenizer-type: TikTokenizer
--tiktoken-pattern: v2
--transformer-impl: transformer_engine
# Model parallelism: none (single device)
--tensor-model-parallel-size: 1
--pipeline-model-parallel-size: 1
--use-mcore-models: true
--is-hybrid-model: true
--model-provider: mamba
--disable-bias-linear: true
--position-embedding-type: rope
# MLA dimensions for the proxy model
--multi-latent-attention: true
--q-lora-rank: 64
--kv-lora-rank: 64
--qk-head-dim: 64
--qk-pos-emb-head-dim: 32
--v-head-dim: 64
# DSA indexer hyperparameters
--dsa-indexer-n-heads: 8
--dsa-indexer-head-dim: 64
--dsa-indexer-topk: 32
--num-layers: 8
--hidden-size: 256
--num-attention-heads: 16
# Alternating DSA attention ("S") and MLP ("-") layers
--hybrid-override-pattern: "S-S-S-S-"
--spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec
--normalization: RMSNorm
--swiglu: true
--bf16: true
--attention-backend: flash
--deterministic-mode: true
# Greedy decoding (top_k=1) so outputs are reproducible
--temperature: 1.0
--top_k: 1
--return-log-probs: true
--num-tokens-to-generate: 30
--inference-max-seq-length: 256
--output-path: ${INFERENCE_OUTPUT_PATH}
--prompts: "The quick brown fox jumps over the lazy dog."
--incoming-requests-per-sec: -1
# Metrics the harness compares — presumably against the adjacent golden JSON
METRICS:
- "generated_tokens"
- "logprobs"
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"0": {"input_prompt": "The quick brown fox jumps over the lazy dog.", "generated_text": "", "generated_tokens": [], "logprobs": []}}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Functional-test spec: same frozen-start DSA inference test as the TP=1
# variant, but with tensor model parallelism of 2 to exercise TP sharding
# of the DSA/MLA layers. Environment pinned for determinism.
ENV_VARS:
CUDA_DEVICE_MAX_CONNECTIONS: 1
NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0
NCCL_ALGO: Ring
CUBLAS_WORKSPACE_CONFIG: ":4096:8"
TEST_TYPE: frozen-start
MODE: inference
MODEL_ARGS:
# Checkpoint and tokenizer
--load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_v3_proxy/dcp/checkpoint
--tokenizer-type: TikTokenizer
--tiktoken-pattern: v2
--transformer-impl: transformer_engine
# Only difference from the TP=1 spec: two-way tensor parallelism
--tensor-model-parallel-size: 2
--pipeline-model-parallel-size: 1
--use-mcore-models: true
--is-hybrid-model: true
--model-provider: mamba
--disable-bias-linear: true
--position-embedding-type: rope
# MLA dimensions for the proxy model
--multi-latent-attention: true
--q-lora-rank: 64
--kv-lora-rank: 64
--qk-head-dim: 64
--qk-pos-emb-head-dim: 32
--v-head-dim: 64
# DSA indexer hyperparameters
--dsa-indexer-n-heads: 8
--dsa-indexer-head-dim: 64
--dsa-indexer-topk: 32
--num-layers: 8
--hidden-size: 256
--num-attention-heads: 16
# Alternating DSA attention ("S") and MLP ("-") layers
--hybrid-override-pattern: "S-S-S-S-"
--spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec
--normalization: RMSNorm
--swiglu: true
--bf16: true
--attention-backend: flash
--deterministic-mode: true
# Greedy decoding (top_k=1) so outputs are reproducible
--temperature: 1.0
--top_k: 1
--return-log-probs: true
--num-tokens-to-generate: 30
--inference-max-seq-length: 256
--output-path: ${INFERENCE_OUTPUT_PATH}
--prompts: "The quick brown fox jumps over the lazy dog."
--incoming-requests-per-sec: -1
# Metrics the harness compares — presumably against the adjacent golden JSON
METRICS:
- "generated_tokens"
- "logprobs"
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"0": {"input_prompt": "The quick brown fox jumps over the lazy dog.", "generated_text": "", "generated_tokens": [], "logprobs": []}}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Functional-test spec: frozen-start DSA inference with MoE layers mixed in
# (pattern "S-S-SESE": DSA attention "S", dense MLP "-", MoE "E"), TP=1 / PP=1.
# Environment pinned for determinism.
ENV_VARS:
CUDA_DEVICE_MAX_CONNECTIONS: 1
NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0
NCCL_ALGO: Ring
CUBLAS_WORKSPACE_CONFIG: ":4096:8"
TEST_TYPE: frozen-start
MODE: inference
MODEL_ARGS:
# Checkpoint and tokenizer
--load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_v3_proxy/dcp/checkpoint
--tokenizer-type: TikTokenizer
--tiktoken-pattern: v2
--transformer-impl: transformer_engine
--tensor-model-parallel-size: 1
--pipeline-model-parallel-size: 1
--use-mcore-models: true
--is-hybrid-model: true
--model-provider: mamba
--disable-bias-linear: true
--position-embedding-type: rope
# MLA dimensions for the proxy model
--multi-latent-attention: true
--q-lora-rank: 64
--kv-lora-rank: 64
--qk-head-dim: 64
--qk-pos-emb-head-dim: 32
--v-head-dim: 64
# DSA indexer hyperparameters
--dsa-indexer-n-heads: 8
--dsa-indexer-head-dim: 64
--dsa-indexer-topk: 32
--num-layers: 8
--hidden-size: 256
--num-attention-heads: 16
# Last two MLP slots replaced by MoE layers ("E")
--hybrid-override-pattern: "S-S-SESE"
--spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec
--normalization: RMSNorm
--swiglu: true
--bf16: true
--attention-backend: flash
--deterministic-mode: true
# Greedy decoding (top_k=1) so outputs are reproducible
--temperature: 1.0
--top_k: 1
--return-log-probs: true
--num-tokens-to-generate: 30
--inference-max-seq-length: 256
--output-path: ${INFERENCE_OUTPUT_PATH}
--prompts: "The quick brown fox jumps over the lazy dog."
--incoming-requests-per-sec: -1
# MoE args
# moe-layer-freq must agree with the "E" positions in the override pattern —
# TODO confirm ([0]*2+[1]*2) lines up with "S-S-SESE" for 8 layers.
--num-experts: 4
--moe-layer-freq: ([0]*2+[1]*2)
--moe-ffn-hidden-size: 512
--moe-shared-expert-intermediate-size: 512
--moe-router-topk: 2
--moe-grouped-gemm: true
--moe-token-dispatcher-type: allgather
--moe-router-load-balancing-type: aux_loss
# aux loss disabled — routing is frozen for the deterministic check
--moe-aux-loss-coeff: 0.0
METRICS:
- "generated_tokens"
- "logprobs"
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"0": {"input_prompt": "The quick brown fox jumps over the lazy dog.", "generated_text": "", "generated_tokens": [], "logprobs": []}}
Loading