8 changes: 5 additions & 3 deletions nemo_automodel/_transformers/infrastructure.py
@@ -305,9 +305,11 @@ def instantiate_infrastructure(
if ep_size > 1:
from nemo_automodel.components.moe.parallelizer import parallelize_model

parallelize_fn = partial(
parallelize_model, activation_checkpointing=activation_checkpointing, **moe_config.to_dict()
)
moe_kwargs = moe_config.to_dict()
# Forward mp_policy from distributed config if not explicitly set in MoE config
if moe_kwargs.get("mp_policy") is None and model_wrapper is not None:
moe_kwargs["mp_policy"] = getattr(model_wrapper, "mp_policy", None)
parallelize_fn = partial(parallelize_model, activation_checkpointing=activation_checkpointing, **moe_kwargs)
elif autopipeline is not None and model_wrapper is not None:
parallelize_fn = partial(parallelize_for_pp, model_wrapper=model_wrapper)

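A minimal sketch of the fallback this hunk adds, using toy stand-ins (`parallelize_model` and `_Wrapper` below are placeholders, not the real nemo_automodel objects): an mp_policy set on the MoE config wins, otherwise the model wrapper's policy from the distributed config is forwarded.

from functools import partial

def parallelize_model(model, *, activation_checkpointing=False, mp_policy=None, **_):
    # Toy stand-in for the real parallelizer; it just reports which policy it saw.
    return mp_policy

class _Wrapper:
    # Toy stand-in for the FSDP2 model wrapper carrying the distributed mp_policy.
    mp_policy = "wrapper-policy"

wrapper = _Wrapper()              # stands in for model_wrapper
moe_kwargs = {"mp_policy": None}  # the MoE config did not set a policy explicitly
if moe_kwargs.get("mp_policy") is None and wrapper is not None:
    moe_kwargs["mp_policy"] = getattr(wrapper, "mp_policy", None)

parallelize_fn = partial(parallelize_model, activation_checkpointing=True, **moe_kwargs)
assert parallelize_fn(model=None) == "wrapper-policy"  # wrapper policy was forwarded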
18 changes: 17 additions & 1 deletion nemo_automodel/components/checkpoint/checkpointing.py
@@ -659,7 +659,23 @@ def initialize_model_weights(
module._is_hf_initialized = False

if hasattr(model, "initialize_weights"):
model.initialize_weights()
# Infer the target dtype from existing (floating-point)
# parameters so that a model constructed in fp32 (e.g. for fp32
# master weights under FSDP2) is not silently cast back to
# bf16 inside model.initialize_weights() -> cast_model_to_dtype().
param_dtype = None
for p in model.parameters():
if p.is_floating_point():
param_dtype = p.dtype
break
try:
if param_dtype is not None:
model.initialize_weights(dtype=param_dtype)
else:
model.initialize_weights()
except TypeError:
# Model's initialize_weights() does not accept a dtype kwarg.
model.initialize_weights()
else:
logging.warning(
"Warning: Model does not have initialize_weights method."
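A self-contained sketch of the behavior this hunk is after; `_TinyModel` is a toy stand-in, not a repo class. The dtype of the first floating-point parameter is forwarded when initialize_weights accepts it, and the plain call is the fallback otherwise, so an fp32-constructed model keeps fp32 weights.

import torch
import torch.nn as nn

class _TinyModel(nn.Module):
    # Toy model whose initialize_weights accepts an optional dtype kwarg.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4, dtype=torch.float32)

    def initialize_weights(self, dtype=None):
        nn.init.normal_(self.proj.weight)
        if dtype is not None:
            self.to(dtype)

model = _TinyModel()
# Infer the dtype of the first floating-point parameter, as the patch does.
param_dtype = next((p.dtype for p in model.parameters() if p.is_floating_point()), None)
try:
    if param_dtype is not None:
        model.initialize_weights(dtype=param_dtype)
    else:
        model.initialize_weights()
except TypeError:
    # initialize_weights() without a dtype kwarg: fall back to the plain call.
    model.initialize_weights()
assert model.proj.weight.dtype == torch.float32  # not silently cast to bf16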
17 changes: 14 additions & 3 deletions nemo_automodel/components/models/deepseek_v3/layers.py
@@ -32,6 +32,7 @@
apply_rotary_emb_qk,
yarn_get_mscale,
)
from nemo_automodel.shared.utils import dtype_from_str as get_dtype


class MLA(nn.Module):
@@ -55,48 +56,58 @@ def __init__(self, config: DeepseekV3Config, backend: BackendConfig):
rms_norm_impl = backend.rms_norm

hidden_size = config.hidden_size
dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)

if self.q_lora_rank is None:
self.q_proj = initialize_linear_module(
linear_impl=linear_impl,
in_features=hidden_size,
out_features=self.n_heads * self.qk_head_dim,
bias=False,
dtype=dtype,
)
else:
self.q_a_proj = initialize_linear_module(
linear_impl=linear_impl, in_features=hidden_size, out_features=self.q_lora_rank, bias=False
linear_impl=linear_impl,
in_features=hidden_size,
out_features=self.q_lora_rank,
bias=False,
dtype=dtype,
)
self.q_a_layernorm = initialize_rms_norm_module(
rms_norm_impl=rms_norm_impl, dim=self.q_lora_rank, eps=config.rms_norm_eps
rms_norm_impl=rms_norm_impl, dim=self.q_lora_rank, eps=config.rms_norm_eps, dtype=dtype
)
self.q_b_proj = initialize_linear_module(
linear_impl=linear_impl,
in_features=self.q_lora_rank,
out_features=self.n_heads * self.qk_head_dim,
bias=False,
dtype=dtype,
)

self.kv_a_proj_with_mqa = initialize_linear_module(
linear_impl=linear_impl,
in_features=hidden_size,
out_features=self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
dtype=dtype,
)
self.kv_a_layernorm = initialize_rms_norm_module(
rms_norm_impl=rms_norm_impl, dim=self.kv_lora_rank, eps=config.rms_norm_eps
rms_norm_impl=rms_norm_impl, dim=self.kv_lora_rank, eps=config.rms_norm_eps, dtype=dtype
)
self.kv_b_proj = initialize_linear_module(
linear_impl=linear_impl,
in_features=self.kv_lora_rank,
out_features=self.n_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
dtype=dtype,
)
self.o_proj = initialize_linear_module(
linear_impl=linear_impl,
in_features=self.n_heads * self.v_head_dim,
out_features=hidden_size,
bias=False,
dtype=dtype,
)
self.softmax_scale = self.qk_head_dim**-0.5

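The layers above resolve their parameter dtype through nemo_automodel.shared.utils.dtype_from_str (imported as get_dtype). The following is a hypothetical reimplementation, for illustration only, of what that helper is assumed to do: accept a torch.dtype, a string such as "float32" or "torch.float32", or None, and fall back to the given default.

import torch

def dtype_from_str(value, default=torch.bfloat16):
    # Illustration only; the real helper in nemo_automodel.shared.utils may differ.
    if value is None:
        return default
    if isinstance(value, torch.dtype):
        return value
    name = str(value).removeprefix("torch.")
    return getattr(torch, name, default)

assert dtype_from_str(None) is torch.bfloat16          # unset config.torch_dtype
assert dtype_from_str("float32") is torch.float32      # fp32 master weights
assert dtype_from_str(torch.float16) is torch.float16  # dtype passed through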
34 changes: 25 additions & 9 deletions nemo_automodel/components/models/deepseek_v3/model.py
@@ -46,13 +46,20 @@ def __init__(
):
super().__init__()
self.self_attn = MLA(config, backend)

# Thread dtype from config.torch_dtype so the block's own params stay
# aligned with the rest of the model (fp32 under fp32 master weights).
dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)

if layer_idx < config.first_k_dense_replace:
self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear)
self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype)
else:
self.mlp = MoE(moe_config, backend)
self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
self.input_layernorm = initialize_rms_norm_module(
backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
)
self.post_attention_layernorm = initialize_rms_norm_module(
backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps
backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
)
self.layer_idx = layer_idx

@@ -127,6 +134,11 @@ def __init__(
self.config = config
if moe_config is not None and moe_overrides is not None:
raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.")
# Resolve model dtype once; thread it explicitly to every sub-module
# so fp32 master weights work even when construction is not wrapped in
# local_torch_dtype().
model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)

moe_defaults = dict(
dim=config.hidden_size,
inter_dim=config.intermediate_size,
@@ -142,17 +154,18 @@
route_scale=config.routed_scaling_factor,
aux_loss_coeff=0,
norm_topk_prob=config.norm_topk_prob,
dtype=model_dtype,
)
if moe_overrides:
moe_defaults.update(moe_overrides)
self.moe_config = moe_config or MoEConfig(**moe_defaults)
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype)
self.layers = torch.nn.ModuleDict()
for layer_id in range(config.num_hidden_layers):
self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend)
self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
self.norm = initialize_rms_norm_module(
backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype
)

self.max_seq_len = config.max_position_embeddings
rope_theta, rope_scaling, _ = get_rope_config(config)
@@ -282,10 +295,13 @@ def __init__(
moe_config=moe_config,
moe_overrides=moe_overrides,
)
self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False)
model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
self.lm_head = initialize_linear_module(
self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype
)
if self.backend.enable_hf_state_dict_adapter:
self.state_dict_adapter = DeepSeekV3StateDictAdapter(
self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
self.config, self.model.moe_config, self.backend, dtype=model_dtype
)

def get_input_embeddings(self):
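The model and layer changes above all aim at one invariant: every sub-module's floating-point parameters share the dtype resolved from config.torch_dtype. A small helper like the one below could verify that in a test; assert_uniform_dtype is an assumption for illustration, not part of the repo.

import torch

def assert_uniform_dtype(module: torch.nn.Module, expected: torch.dtype) -> None:
    # Walk every floating-point parameter and check it matches the resolved dtype.
    for name, param in module.named_parameters():
        if param.is_floating_point():
            assert param.dtype == expected, f"{name} is {param.dtype}, expected {expected}"

# e.g. after building the model with config.torch_dtype = "float32":
#     assert_uniform_dtype(model, torch.float32)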
24 changes: 20 additions & 4 deletions nemo_automodel/components/models/deepseek_v32/layers.py
@@ -72,6 +72,7 @@ def hadamard_transform(x: torch.Tensor, scale: float) -> torch.Tensor:
yarn_get_mscale,
)
from nemo_automodel.components.models.deepseek_v32.config import DeepseekV32Config
from nemo_automodel.shared.utils import dtype_from_str as get_dtype


def _rotate_activation(x: torch.Tensor) -> torch.Tensor:
@@ -120,13 +121,15 @@ def __init__(self, config: DeepseekV32Config, backend: BackendConfig):

self.backend = backend
linear_impl = backend.linear
dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)

# Project Q from q_lora residual -> num_heads * head_dim
self.wq_b = initialize_linear_module(
linear_impl=linear_impl,
in_features=self.q_lora_rank,
out_features=self.num_heads * self.head_dim,
bias=False,
dtype=dtype,
)

# Project K from hidden states -> single head_dim (shared across heads)
@@ -135,17 +138,19 @@ def __init__(self, config: DeepseekV32Config, backend: BackendConfig):
in_features=self.hidden_size,
out_features=self.head_dim,
bias=False,
dtype=dtype,
)

# LayerNorm for K (official uses LayerNorm, not RMSNorm)
self.k_norm = nn.LayerNorm(self.head_dim)
self.k_norm = nn.LayerNorm(self.head_dim, dtype=dtype)

# Per-head weight projection from hidden states
self.weights_proj = initialize_linear_module(
linear_impl=linear_impl,
in_features=self.hidden_size,
out_features=self.num_heads,
bias=False,
dtype=dtype,
)

def forward(
@@ -299,37 +304,48 @@ def __init__(self, config: DeepseekV32Config, backend: BackendConfig):
rms_norm_impl = backend.rms_norm

hidden_size = config.hidden_size
dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)

# V3.2 always uses q_lora (q_lora_rank is not None)
self.q_a_proj = initialize_linear_module(
linear_impl=linear_impl, in_features=hidden_size, out_features=self.q_lora_rank, bias=False
linear_impl=linear_impl,
in_features=hidden_size,
out_features=self.q_lora_rank,
bias=False,
dtype=dtype,
)
self.q_a_layernorm = initialize_rms_norm_module(rms_norm_impl=rms_norm_impl, dim=self.q_lora_rank)
self.q_a_layernorm = initialize_rms_norm_module(rms_norm_impl=rms_norm_impl, dim=self.q_lora_rank, dtype=dtype)
self.q_b_proj = initialize_linear_module(
linear_impl=linear_impl,
in_features=self.q_lora_rank,
out_features=self.n_heads * self.qk_head_dim,
bias=False,
dtype=dtype,
)

self.kv_a_proj_with_mqa = initialize_linear_module(
linear_impl=linear_impl,
in_features=hidden_size,
out_features=self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
dtype=dtype,
)
self.kv_a_layernorm = initialize_rms_norm_module(
rms_norm_impl=rms_norm_impl, dim=self.kv_lora_rank, dtype=dtype
)
self.kv_a_layernorm = initialize_rms_norm_module(rms_norm_impl=rms_norm_impl, dim=self.kv_lora_rank)
self.kv_b_proj = initialize_linear_module(
linear_impl=linear_impl,
in_features=self.kv_lora_rank,
out_features=self.n_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
dtype=dtype,
)
self.o_proj = initialize_linear_module(
linear_impl=linear_impl,
in_features=self.n_heads * self.v_head_dim,
out_features=hidden_size,
bias=False,
dtype=dtype,
)
self.softmax_scale = self.qk_head_dim**-0.5

34 changes: 25 additions & 9 deletions nemo_automodel/components/models/deepseek_v32/model.py
@@ -59,14 +59,20 @@ def __init__(
from nemo_automodel.components.models.common import initialize_rms_norm_module
from nemo_automodel.components.moe.layers import MLP, MoE

# Thread dtype from config.torch_dtype so the block's own params stay
# aligned with the rest of the model (fp32 under fp32 master weights).
dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)

if layer_idx < config.first_k_dense_replace:
self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear)
self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype)
else:
self.mlp = MoE(moe_config, backend)

self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
self.input_layernorm = initialize_rms_norm_module(
backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
)
self.post_attention_layernorm = initialize_rms_norm_module(
backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps
backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
)
self.layer_idx = layer_idx

@@ -92,6 +98,12 @@ def __init__(
self.config = config
if moe_config is not None and moe_overrides is not None:
raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.")

# Resolve model dtype once; thread it explicitly to every sub-module
# so fp32 master weights work even when construction is not wrapped in
# local_torch_dtype().
model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)

moe_defaults = dict(
dim=config.hidden_size,
inter_dim=config.intermediate_size,
@@ -107,19 +119,20 @@
route_scale=config.routed_scaling_factor,
aux_loss_coeff=0,
norm_topk_prob=config.norm_topk_prob,
dtype=model_dtype,
)
if moe_overrides:
moe_defaults.update(moe_overrides)
self.moe_config = moe_config or MoEConfig(**moe_defaults)

self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype)
self.layers = torch.nn.ModuleDict()
for layer_id in range(config.num_hidden_layers):
# Use V3.2 Block instead of V3 Block
self.layers[str(layer_id)] = DeepseekV32Block(layer_id, config, self.moe_config, backend)
self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
self.norm = initialize_rms_norm_module(
backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype
)

self.max_seq_len = config.max_position_embeddings
rope_theta, rope_scaling, _ = get_rope_config(config)
@@ -183,11 +196,14 @@ def __init__(
moe_config=moe_config,
moe_overrides=moe_overrides,
)
self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False)
model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
self.lm_head = initialize_linear_module(
self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype
)
if self.backend.enable_hf_state_dict_adapter:
# Use V3.2 adapter instead of V3 adapter
self.state_dict_adapter = DeepSeekV32StateDictAdapter(
self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
self.config, self.model.moe_config, self.backend, dtype=model_dtype
)

def get_input_embeddings(self):