diff --git a/nemo_automodel/_transformers/infrastructure.py b/nemo_automodel/_transformers/infrastructure.py
index a5c6ae7917..6fcd3da0e1 100644
--- a/nemo_automodel/_transformers/infrastructure.py
+++ b/nemo_automodel/_transformers/infrastructure.py
@@ -305,9 +305,11 @@ def instantiate_infrastructure(
     if ep_size > 1:
         from nemo_automodel.components.moe.parallelizer import parallelize_model

-        parallelize_fn = partial(
-            parallelize_model, activation_checkpointing=activation_checkpointing, **moe_config.to_dict()
-        )
+        moe_kwargs = moe_config.to_dict()
+        # Forward mp_policy from distributed config if not explicitly set in MoE config
+        if moe_kwargs.get("mp_policy") is None and model_wrapper is not None:
+            moe_kwargs["mp_policy"] = getattr(model_wrapper, "mp_policy", None)
+        parallelize_fn = partial(parallelize_model, activation_checkpointing=activation_checkpointing, **moe_kwargs)
     elif autopipeline is not None and model_wrapper is not None:
         parallelize_fn = partial(parallelize_for_pp, model_wrapper=model_wrapper)
diff --git a/nemo_automodel/components/checkpoint/checkpointing.py b/nemo_automodel/components/checkpoint/checkpointing.py
index 5bc3cb8ad4..f24319e94c 100644
--- a/nemo_automodel/components/checkpoint/checkpointing.py
+++ b/nemo_automodel/components/checkpoint/checkpointing.py
@@ -659,7 +659,23 @@ def initialize_model_weights(
             module._is_hf_initialized = False

     if hasattr(model, "initialize_weights"):
-        model.initialize_weights()
+        # Infer the target dtype from existing (floating-point)
+        # parameters so that a model constructed in fp32 (e.g. for fp32
+        # master weights under FSDP2) is not silently cast back to
+        # bf16 inside model.initialize_weights() -> cast_model_to_dtype().
+        param_dtype = None
+        for p in model.parameters():
+            if p.is_floating_point():
+                param_dtype = p.dtype
+                break
+        try:
+            if param_dtype is not None:
+                model.initialize_weights(dtype=param_dtype)
+            else:
+                model.initialize_weights()
+        except TypeError:
+            # Model's initialize_weights() does not accept a dtype kwarg.
+            model.initialize_weights()
     else:
         logging.warning(
             "Warning: Model does not have initialize_weights method."
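The `initialize_model_weights()` change above combines two plain-Python patterns: probing the first floating-point parameter for its dtype, and degrading gracefully when a model's `initialize_weights()` predates the `dtype` kwarg. A minimal, self-contained sketch of the call pattern — `ToyModel` here is a hypothetical stand-in, not part of the codebase:

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Hypothetical stand-in for a model exposing initialize_weights()."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4, dtype=torch.float32)

    def initialize_weights(self):  # note: no dtype kwarg
        nn.init.normal_(self.proj.weight)

model = ToyModel()

# Infer the target dtype from the first floating-point parameter.
param_dtype = next(
    (p.dtype for p in model.parameters() if p.is_floating_point()), None
)

# Prefer initialize_weights(dtype=...); fall back when unsupported.
try:
    if param_dtype is not None:
        model.initialize_weights(dtype=param_dtype)
    else:
        model.initialize_weights()
except TypeError:
    # Older models whose initialize_weights() lacks a dtype kwarg.
    model.initialize_weights()

assert model.proj.weight.dtype == torch.float32
```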
diff --git a/nemo_automodel/components/models/deepseek_v3/layers.py b/nemo_automodel/components/models/deepseek_v3/layers.py
index 3aef514046..1ab5a6ab5b 100644
--- a/nemo_automodel/components/models/deepseek_v3/layers.py
+++ b/nemo_automodel/components/models/deepseek_v3/layers.py
@@ -32,6 +32,7 @@
     apply_rotary_emb_qk,
     yarn_get_mscale,
 )
+from nemo_automodel.shared.utils import dtype_from_str as get_dtype


 class MLA(nn.Module):
@@ -55,6 +56,7 @@ def __init__(self, config: DeepseekV3Config, backend: BackendConfig):
         rms_norm_impl = backend.rms_norm

         hidden_size = config.hidden_size
+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)

         if self.q_lora_rank is None:
             self.q_proj = initialize_linear_module(
@@ -62,19 +64,25 @@ def __init__(self, config: DeepseekV3Config, backend: BackendConfig):
                 in_features=hidden_size,
                 out_features=self.n_heads * self.qk_head_dim,
                 bias=False,
+                dtype=dtype,
             )
         else:
             self.q_a_proj = initialize_linear_module(
-                linear_impl=linear_impl, in_features=hidden_size, out_features=self.q_lora_rank, bias=False
+                linear_impl=linear_impl,
+                in_features=hidden_size,
+                out_features=self.q_lora_rank,
+                bias=False,
+                dtype=dtype,
             )
             self.q_a_layernorm = initialize_rms_norm_module(
-                rms_norm_impl=rms_norm_impl, dim=self.q_lora_rank, eps=config.rms_norm_eps
+                rms_norm_impl=rms_norm_impl, dim=self.q_lora_rank, eps=config.rms_norm_eps, dtype=dtype
             )
             self.q_b_proj = initialize_linear_module(
                 linear_impl=linear_impl,
                 in_features=self.q_lora_rank,
                 out_features=self.n_heads * self.qk_head_dim,
                 bias=False,
+                dtype=dtype,
             )

         self.kv_a_proj_with_mqa = initialize_linear_module(
@@ -82,21 +90,24 @@ def __init__(self, config: DeepseekV3Config, backend: BackendConfig):
             in_features=hidden_size,
             out_features=self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
+            dtype=dtype,
         )
         self.kv_a_layernorm = initialize_rms_norm_module(
-            rms_norm_impl=rms_norm_impl, dim=self.kv_lora_rank, eps=config.rms_norm_eps
+            rms_norm_impl=rms_norm_impl, dim=self.kv_lora_rank, eps=config.rms_norm_eps, dtype=dtype
         )
         self.kv_b_proj = initialize_linear_module(
             linear_impl=linear_impl,
             in_features=self.kv_lora_rank,
             out_features=self.n_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
+            dtype=dtype,
         )
         self.o_proj = initialize_linear_module(
             linear_impl=linear_impl,
             in_features=self.n_heads * self.v_head_dim,
             out_features=hidden_size,
             bias=False,
+            dtype=dtype,
         )

         self.softmax_scale = self.qk_head_dim**-0.5
diff --git a/nemo_automodel/components/models/deepseek_v3/model.py b/nemo_automodel/components/models/deepseek_v3/model.py
index ae3c9a0000..e5cf308fbc 100644
--- a/nemo_automodel/components/models/deepseek_v3/model.py
+++ b/nemo_automodel/components/models/deepseek_v3/model.py
@@ -46,13 +46,20 @@ def __init__(
     ):
         super().__init__()
         self.self_attn = MLA(config, backend)
+
+        # Thread dtype from config.torch_dtype so the block's own params stay
+        # aligned with the rest of the model (fp32 under fp32 master weights).
+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         if layer_idx < config.first_k_dense_replace:
-            self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear)
+            self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype)
         else:
             self.mlp = MoE(moe_config, backend)
-        self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
+        )
         self.post_attention_layernorm = initialize_rms_norm_module(
-            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
         )
         self.layer_idx = layer_idx

@@ -127,6 +134,11 @@ def __init__(
         self.config = config
         if moe_config is not None and moe_overrides is not None:
             raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.")
+        # Resolve model dtype once; thread it explicitly to every sub-module
+        # so fp32 master weights work even when construction is not wrapped in
+        # local_torch_dtype().
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         moe_defaults = dict(
             dim=config.hidden_size,
             inter_dim=config.intermediate_size,
@@ -142,17 +154,18 @@ def __init__(
             route_scale=config.routed_scaling_factor,
             aux_loss_coeff=0,
             norm_topk_prob=config.norm_topk_prob,
+            dtype=model_dtype,
         )
         if moe_overrides:
             moe_defaults.update(moe_overrides)
         self.moe_config = moe_config or MoEConfig(**moe_defaults)

-        self.embed_tokens = nn.Embedding(
-            config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
-        )
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype)
         self.layers = torch.nn.ModuleDict()
         for layer_id in range(config.num_hidden_layers):
             self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend)
-        self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype
+        )

         self.max_seq_len = config.max_position_embeddings
         rope_theta, rope_scaling, _ = get_rope_config(config)
@@ -282,10 +295,13 @@ def __init__(
             moe_config=moe_config,
             moe_overrides=moe_overrides,
         )
-        self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False)
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+        self.lm_head = initialize_linear_module(
+            self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype
+        )
         if self.backend.enable_hf_state_dict_adapter:
             self.state_dict_adapter = DeepSeekV3StateDictAdapter(
-                self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
+                self.config, self.model.moe_config, self.backend, dtype=model_dtype
             )

     def get_input_embeddings(self):
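Throughout this patch, `dtype_from_str` is imported as `get_dtype` and called as `get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)`. A rough sketch of the resolution behavior these call sites assume — the real helper lives in `nemo_automodel.shared.utils` and its implementation may differ:

```python
import torch

def dtype_from_str(value, default=torch.bfloat16):
    """Resolve a config-style dtype value to a torch.dtype.

    Accepts a torch.dtype, a string such as "float32" or "torch.float32",
    or None, in which case `default` is returned.
    """
    if value is None:
        return default
    if isinstance(value, torch.dtype):
        return value
    name = str(value).replace("torch.", "")
    resolved = getattr(torch, name, None)
    return resolved if isinstance(resolved, torch.dtype) else default

# The pattern used at every construction site in this PR:
class Cfg:
    torch_dtype = "float32"

dtype = dtype_from_str(getattr(Cfg, "torch_dtype", None), torch.bfloat16)
assert dtype == torch.float32
```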
diff --git a/nemo_automodel/components/models/deepseek_v32/layers.py b/nemo_automodel/components/models/deepseek_v32/layers.py
index a3e7938e79..8f66b8fac1 100644
--- a/nemo_automodel/components/models/deepseek_v32/layers.py
+++ b/nemo_automodel/components/models/deepseek_v32/layers.py
@@ -72,6 +72,7 @@ def hadamard_transform(x: torch.Tensor, scale: float) -> torch.Tensor:
     yarn_get_mscale,
 )
 from nemo_automodel.components.models.deepseek_v32.config import DeepseekV32Config
+from nemo_automodel.shared.utils import dtype_from_str as get_dtype


 def _rotate_activation(x: torch.Tensor) -> torch.Tensor:
@@ -120,6 +121,7 @@ def __init__(self, config: DeepseekV32Config, backend: BackendConfig):
         self.backend = backend

         linear_impl = backend.linear
+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)

         # Project Q from q_lora residual -> num_heads * head_dim
         self.wq_b = initialize_linear_module(
@@ -127,6 +129,7 @@ def __init__(self, config: DeepseekV32Config, backend: BackendConfig):
             in_features=self.q_lora_rank,
             out_features=self.num_heads * self.head_dim,
             bias=False,
+            dtype=dtype,
         )

         # Project K from hidden states -> single head_dim (shared across heads)
@@ -135,10 +138,11 @@ def __init__(self, config: DeepseekV32Config, backend: BackendConfig):
             in_features=self.hidden_size,
             out_features=self.head_dim,
             bias=False,
+            dtype=dtype,
         )

         # LayerNorm for K (official uses LayerNorm, not RMSNorm)
-        self.k_norm = nn.LayerNorm(self.head_dim)
+        self.k_norm = nn.LayerNorm(self.head_dim, dtype=dtype)

         # Per-head weight projection from hidden states
         self.weights_proj = initialize_linear_module(
@@ -146,6 +150,7 @@ def __init__(self, config: DeepseekV32Config, backend: BackendConfig):
             in_features=self.hidden_size,
             out_features=self.num_heads,
             bias=False,
+            dtype=dtype,
         )

     def forward(
@@ -299,17 +304,23 @@ def __init__(self, config: DeepseekV32Config, backend: BackendConfig):
         rms_norm_impl = backend.rms_norm

         hidden_size = config.hidden_size
+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)

         # V3.2 always uses q_lora (q_lora_rank is not None)
         self.q_a_proj = initialize_linear_module(
-            linear_impl=linear_impl, in_features=hidden_size, out_features=self.q_lora_rank, bias=False
+            linear_impl=linear_impl,
+            in_features=hidden_size,
+            out_features=self.q_lora_rank,
+            bias=False,
+            dtype=dtype,
         )
-        self.q_a_layernorm = initialize_rms_norm_module(rms_norm_impl=rms_norm_impl, dim=self.q_lora_rank)
+        self.q_a_layernorm = initialize_rms_norm_module(rms_norm_impl=rms_norm_impl, dim=self.q_lora_rank, dtype=dtype)
         self.q_b_proj = initialize_linear_module(
             linear_impl=linear_impl,
             in_features=self.q_lora_rank,
             out_features=self.n_heads * self.qk_head_dim,
             bias=False,
+            dtype=dtype,
         )

         self.kv_a_proj_with_mqa = initialize_linear_module(
@@ -317,19 +328,24 @@ def __init__(self, config: DeepseekV32Config, backend: BackendConfig):
             in_features=hidden_size,
             out_features=self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
+            dtype=dtype,
+        )
+        self.kv_a_layernorm = initialize_rms_norm_module(
+            rms_norm_impl=rms_norm_impl, dim=self.kv_lora_rank, dtype=dtype
         )
-        self.kv_a_layernorm = initialize_rms_norm_module(rms_norm_impl=rms_norm_impl, dim=self.kv_lora_rank)
         self.kv_b_proj = initialize_linear_module(
             linear_impl=linear_impl,
             in_features=self.kv_lora_rank,
             out_features=self.n_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
+            dtype=dtype,
         )
         self.o_proj = initialize_linear_module(
             linear_impl=linear_impl,
             in_features=self.n_heads * self.v_head_dim,
             out_features=hidden_size,
             bias=False,
+            dtype=dtype,
         )

         self.softmax_scale = self.qk_head_dim**-0.5
diff --git a/nemo_automodel/components/models/deepseek_v32/model.py b/nemo_automodel/components/models/deepseek_v32/model.py
index cc4ddecc10..ffb5fe0520 100644
--- a/nemo_automodel/components/models/deepseek_v32/model.py
+++ b/nemo_automodel/components/models/deepseek_v32/model.py
@@ -59,14 +59,20 @@ def __init__(
         from nemo_automodel.components.models.common import initialize_rms_norm_module
         from nemo_automodel.components.moe.layers import MLP, MoE

+        # Thread dtype from config.torch_dtype so the block's own params stay
+        # aligned with the rest of the model (fp32 under fp32 master weights).
+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         if layer_idx < config.first_k_dense_replace:
-            self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear)
+            self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype)
         else:
             self.mlp = MoE(moe_config, backend)
-        self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
+        )
         self.post_attention_layernorm = initialize_rms_norm_module(
-            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
         )
         self.layer_idx = layer_idx

@@ -92,6 +98,12 @@ def __init__(
         self.config = config
         if moe_config is not None and moe_overrides is not None:
             raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.")
+
+        # Resolve model dtype once; thread it explicitly to every sub-module
+        # so fp32 master weights work even when construction is not wrapped in
+        # local_torch_dtype().
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         moe_defaults = dict(
             dim=config.hidden_size,
             inter_dim=config.intermediate_size,
@@ -107,19 +119,20 @@ def __init__(
             route_scale=config.routed_scaling_factor,
             aux_loss_coeff=0,
             norm_topk_prob=config.norm_topk_prob,
+            dtype=model_dtype,
         )
         if moe_overrides:
             moe_defaults.update(moe_overrides)
         self.moe_config = moe_config or MoEConfig(**moe_defaults)

-        self.embed_tokens = nn.Embedding(
-            config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
-        )
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype)
         self.layers = torch.nn.ModuleDict()
         for layer_id in range(config.num_hidden_layers):
             # Use V3.2 Block instead of V3 Block
             self.layers[str(layer_id)] = DeepseekV32Block(layer_id, config, self.moe_config, backend)
-        self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype
+        )

         self.max_seq_len = config.max_position_embeddings
         rope_theta, rope_scaling, _ = get_rope_config(config)
@@ -183,11 +196,14 @@ def __init__(
             moe_config=moe_config,
             moe_overrides=moe_overrides,
         )
-        self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False)
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+        self.lm_head = initialize_linear_module(
+            self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype
+        )
         if self.backend.enable_hf_state_dict_adapter:
             # Use V3.2 adapter instead of V3 adapter
             self.state_dict_adapter = DeepSeekV32StateDictAdapter(
-                self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
+                self.config, self.model.moe_config, self.backend, dtype=model_dtype
             )

     def get_input_embeddings(self):
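Passing `dtype=` explicitly matters because factory-style modules otherwise inherit `torch.get_default_dtype()` at construction time; the `local_torch_dtype()` context the comments refer to works by temporarily swapping that default. A small demonstration of the difference, using plain torch with no NeMo helpers:

```python
import torch
import torch.nn as nn

# Without an explicit dtype, module parameters follow
# torch.get_default_dtype() at construction time...
torch.set_default_dtype(torch.bfloat16)
implicit = nn.Linear(8, 8)
assert implicit.weight.dtype == torch.bfloat16

# ...whereas passing dtype= pins the parameters regardless of any
# surrounding default-dtype context, which is what this PR relies on
# for fp32 master weights.
explicit = nn.Linear(8, 8, dtype=torch.float32)
assert explicit.weight.dtype == torch.float32

torch.set_default_dtype(torch.float32)  # restore the usual default
```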
diff --git a/nemo_automodel/components/models/gemma4_moe/model.py b/nemo_automodel/components/models/gemma4_moe/model.py
index b821277392..3f326ff350 100644
--- a/nemo_automodel/components/models/gemma4_moe/model.py
+++ b/nemo_automodel/components/models/gemma4_moe/model.py
@@ -287,6 +287,14 @@ def __init__(
         self.padding_idx = getattr(config, "pad_token_id", None)
         self.vocab_size = config.vocab_size

+        # Resolve model dtype once from config.torch_dtype and thread it into
+        # MoEConfig so the MoE expert params stay aligned with the rest of the
+        # model (fp32 under fp32 master weights). HF-native submodules
+        # (attention, Gemma4MLP, Gemma4RMSNorm, Gemma4TextScaledWordEmbedding)
+        # inherit their dtype from torch.get_default_dtype() via the
+        # local_torch_dtype() context established by _init_model().
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         moe_defaults = dict(
             dim=config.hidden_size,
             inter_dim=config.intermediate_size,
@@ -305,12 +313,12 @@ def __init__(
             norm_topk_prob=True,
             expert_activation="geglu",
             softmax_before_topk=False,
+            dtype=model_dtype,
         )
         if moe_overrides:
             moe_defaults.update(moe_overrides)
         self.moe_config = moe_config or MoEConfig(**moe_defaults)

-        get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
         self.embed_tokens = Gemma4TextScaledWordEmbedding(
             config.vocab_size,
             config.hidden_size,
@@ -497,6 +505,18 @@ def __init__(
         if not getattr(cfg_text, "expert_intermediate_size", None) and getattr(cfg_text, "moe_intermediate_size", None):
             cfg_text.expert_intermediate_size = cfg_text.moe_intermediate_size

+        # _init_model() only overrides the top-level hf_config.torch_dtype; for
+        # VL configs the nested text_config / vision_config keep their original
+        # dtype (typically bf16 from the checkpoint's config.json). Propagate
+        # the user-requested dtype to every nested sub-config that exposes a
+        # torch_dtype attribute, before constructing the HF parent and our
+        # text backend.
+        top_dtype = getattr(config, "torch_dtype", None)
+        if top_dtype is not None:
+            for sub_cfg in vars(config).values():
+                if sub_cfg is not config and hasattr(sub_cfg, "torch_dtype"):
+                    sub_cfg.torch_dtype = top_dtype
+
         # Initialize the HF parent (creates self.model, self.lm_head, vision tower, etc.)
         super().__init__(config)
diff --git a/nemo_automodel/components/models/glm4_moe/layers.py b/nemo_automodel/components/models/glm4_moe/layers.py
index b4ddd5766f..ea6202679d 100644
--- a/nemo_automodel/components/models/glm4_moe/layers.py
+++ b/nemo_automodel/components/models/glm4_moe/layers.py
@@ -29,6 +29,7 @@
     initialize_rms_norm_module,
 )
 from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk
+from nemo_automodel.shared.utils import dtype_from_str as get_dtype


 class Glm4MoeAttention(nn.Module):
@@ -58,23 +59,29 @@ def __init__(self, config: Glm4MoeConfig, backend: BackendConfig):
         self.use_qk_norm = config.use_qk_norm
         self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)

+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         self.q_proj = initialize_linear_module(
-            backend.linear, config.hidden_size, self.num_heads * self.head_dim, config.attention_bias
+            backend.linear, config.hidden_size, self.num_heads * self.head_dim, config.attention_bias, dtype=dtype
         )
         self.k_proj = initialize_linear_module(
-            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, config.attention_bias
+            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, config.attention_bias, dtype=dtype
         )
         self.v_proj = initialize_linear_module(
-            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, config.attention_bias
+            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, config.attention_bias, dtype=dtype
         )
         self.o_proj = initialize_linear_module(
-            backend.linear, self.num_heads * self.head_dim, config.hidden_size, False
+            backend.linear, self.num_heads * self.head_dim, config.hidden_size, False, dtype=dtype
         )

         # Optional per-head RMSNorm for Q and K
         if self.use_qk_norm:
-            self.q_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps)
-            self.k_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps)
+            self.q_norm = initialize_rms_norm_module(
+                backend.rms_norm, self.head_dim, eps=config.rms_norm_eps, dtype=dtype
+            )
+            self.k_norm = initialize_rms_norm_module(
+                backend.rms_norm, self.head_dim, eps=config.rms_norm_eps, dtype=dtype
+            )

         # Attention implementation
         softmax_scale = self.head_dim**-0.5
diff --git a/nemo_automodel/components/models/glm4_moe/model.py b/nemo_automodel/components/models/glm4_moe/model.py
index b4a9b6668b..e1b0f0ea43 100644
--- a/nemo_automodel/components/models/glm4_moe/model.py
+++ b/nemo_automodel/components/models/glm4_moe/model.py
@@ -41,16 +41,22 @@ def __init__(self, layer_idx: int, config: Glm4MoeConfig, moe_config: MoEConfig,
         super().__init__()
         self.self_attn = Glm4MoeAttention(config, backend)

+        # Thread dtype from config.torch_dtype so the block's own params stay
+        # aligned with the rest of the model (fp32 under fp32 master weights).
+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         # GLM4-MoE uses dense layers for first_k_dense_replace layers, then MoE
         is_moe_layer = layer_idx >= config.first_k_dense_replace
         if is_moe_layer:
             self.mlp = MoE(moe_config, backend)
         else:
-            self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear)
+            self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype)

-        self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
+        )
         self.post_attention_layernorm = initialize_rms_norm_module(
-            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
         )
         self.layer_idx = layer_idx

@@ -107,6 +113,11 @@ def __init__(
         if moe_config is not None and moe_overrides is not None:
             raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.")

+        # Resolve model dtype once; thread it explicitly to every sub-module
+        # so fp32 master weights work even when construction is not wrapped in
+        # local_torch_dtype().
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         # Map HF GLM4 MoE config -> our MoE wrapper
         # GLM4 MoE config fields:
         # - hidden_size, intermediate_size, moe_intermediate_size
@@ -131,18 +142,19 @@ def __init__(
             router_bias=False,
             expert_activation="swiglu",
             softmax_before_topk=False,  # GLM4 uses sigmoid, not softmax
+            dtype=model_dtype,
         )
         if moe_overrides:
             moe_defaults.update(moe_overrides)
         self.moe_config = moe_config or MoEConfig(**moe_defaults)

-        self.embed_tokens = nn.Embedding(
-            config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
-        )
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype)
         self.layers = torch.nn.ModuleDict()
         for layer_id in range(config.num_hidden_layers):
             self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend)
-        self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype
+        )

         # Rotary embedding cache compatible with our rope_utils functions
         # GLM4 MoE uses partial rotary embeddings
@@ -257,10 +269,13 @@ def __init__(
             moe_config=moe_config,
             moe_overrides=moe_overrides,
         )
-        self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False)
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+        self.lm_head = initialize_linear_module(
+            self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype
+        )
         if self.backend.enable_hf_state_dict_adapter:
             self.state_dict_adapter = Glm4MoeStateDictAdapter(
-                self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
+                self.config, self.model.moe_config, self.backend, dtype=model_dtype
             )

     def get_input_embeddings(self):
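The sub-config propagation loop added for the Gemma4 VL wrapper (and repeated below for Mistral4 and Qwen3.5-MoE) is plain attribute surgery over `vars(config)`. A standalone sketch with a `SimpleNamespace` standing in for the HF config object:

```python
from types import SimpleNamespace

# Hypothetical VL-style config: only the top level carries the
# user-requested dtype; nested sub-configs still hold checkpoint values.
config = SimpleNamespace(
    torch_dtype="float32",
    text_config=SimpleNamespace(torch_dtype="bfloat16"),
    vision_config=SimpleNamespace(torch_dtype="bfloat16"),
    vocab_size=32000,  # plain attribute, left untouched by the loop
)

# Same propagation loop as in the PR: copy the top-level torch_dtype to
# every nested sub-config that exposes one.
top_dtype = getattr(config, "torch_dtype", None)
if top_dtype is not None:
    for sub_cfg in vars(config).values():
        if sub_cfg is not config and hasattr(sub_cfg, "torch_dtype"):
            sub_cfg.torch_dtype = top_dtype

assert config.text_config.torch_dtype == "float32"
assert config.vision_config.torch_dtype == "float32"
```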
diff --git a/nemo_automodel/components/models/glm4_moe_lite/model.py b/nemo_automodel/components/models/glm4_moe_lite/model.py
index b60dafa13c..23d71719d8 100644
--- a/nemo_automodel/components/models/glm4_moe_lite/model.py
+++ b/nemo_automodel/components/models/glm4_moe_lite/model.py
@@ -42,6 +42,10 @@ def __init__(self, layer_idx: int, config: Any, moe_config: MoEConfig, backend:
         super().__init__()
         self.self_attn = MLA(config, backend)

+        # Thread dtype from config.torch_dtype so the block's own params stay
+        # aligned with the rest of the model (fp32 under fp32 master weights).
+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         # GLM4-MoE-Lite uses mlp_layer_types to determine dense vs MoE layers
         mlp_layer_types = getattr(config, "mlp_layer_types", None)
         if mlp_layer_types is not None:
@@ -54,11 +58,13 @@ def __init__(self, layer_idx: int, config: Any, moe_config: MoEConfig, backend:
         if is_moe_layer:
             self.mlp = MoE(moe_config, backend)
         else:
-            self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear)
+            self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype)

-        self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
+        )
         self.post_attention_layernorm = initialize_rms_norm_module(
-            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
         )
         self.layer_idx = layer_idx

@@ -115,6 +121,11 @@ def __init__(
         if moe_config is not None and moe_overrides is not None:
             raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.")

+        # Resolve model dtype once; thread it explicitly to every sub-module
+        # so fp32 master weights work even when construction is not wrapped in
+        # local_torch_dtype().
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         # Map config -> MoE wrapper (same as GLM4 MoE)
         moe_defaults = dict(
             dim=config.hidden_size,
@@ -135,18 +146,19 @@ def __init__(
             router_bias=False,
             expert_activation="swiglu",
             softmax_before_topk=False,  # GLM4 uses sigmoid, not softmax
+            dtype=model_dtype,
         )
         if moe_overrides:
             moe_defaults.update(moe_overrides)
         self.moe_config = moe_config or MoEConfig(**moe_defaults)

-        self.embed_tokens = nn.Embedding(
-            config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
-        )
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype)
         self.layers = torch.nn.ModuleDict()
         for layer_id in range(config.num_hidden_layers):
             self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend)
-        self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype
+        )

         # Rotary embedding for MLA
         # MLA uses qk_rope_head_dim for the rope dimension
@@ -258,10 +270,13 @@ def __init__(
             moe_config=moe_config,
             moe_overrides=moe_overrides,
         )
-        self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False)
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+        self.lm_head = initialize_linear_module(
+            self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype
+        )
         if self.backend.enable_hf_state_dict_adapter:
             self.state_dict_adapter = Glm4MoeStateDictAdapter(
-                self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
+                self.config, self.model.moe_config, self.backend, dtype=model_dtype
             )

     def get_input_embeddings(self):
diff --git a/nemo_automodel/components/models/glm_moe_dsa/model.py b/nemo_automodel/components/models/glm_moe_dsa/model.py
index cfa8fe38dc..67bb40d188 100644
--- a/nemo_automodel/components/models/glm_moe_dsa/model.py
+++ b/nemo_automodel/components/models/glm_moe_dsa/model.py
@@ -37,6 +37,10 @@ def __init__(self, layer_idx: int, config: GlmMoeDsaConfig, moe_config: MoEConfi
         super().__init__()
         self.self_attn = DeepseekV32MLA(config, backend)

+        # Thread dtype from config.torch_dtype so the block's own params stay
+        # aligned with the rest of the model (fp32 under fp32 master weights).
+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         mlp_layer_types = getattr(config, "mlp_layer_types", None)
         if mlp_layer_types is not None:
             is_moe_layer = mlp_layer_types[layer_idx] == "sparse"
@@ -47,11 +51,13 @@ def __init__(self, layer_idx: int, config: GlmMoeDsaConfig, moe_config: MoEConfi
         if is_moe_layer:
             self.mlp = MoE(moe_config, backend)
         else:
-            self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear)
+            self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype)

-        self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
+        )
         self.post_attention_layernorm = initialize_rms_norm_module(
-            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
         )
         self.layer_idx = layer_idx

@@ -108,6 +114,11 @@ def __init__(
         if moe_config is not None and moe_overrides is not None:
             raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.")

+        # Resolve model dtype once; thread it explicitly to every sub-module
+        # so fp32 master weights work even when construction is not wrapped in
+        # local_torch_dtype().
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         moe_defaults = dict(
             dim=config.hidden_size,
             inter_dim=config.intermediate_size,
@@ -127,18 +138,19 @@ def __init__(
             router_bias=False,
             expert_activation="swiglu",
             softmax_before_topk=False,
+            dtype=model_dtype,
         )
         if moe_overrides:
             moe_defaults.update(moe_overrides)
         self.moe_config = moe_config or MoEConfig(**moe_defaults)

-        self.embed_tokens = nn.Embedding(
-            config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
-        )
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype)
         self.layers = torch.nn.ModuleDict()
         for layer_id in range(config.num_hidden_layers):
             self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend)
-        self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype
+        )

         self.max_seq_len = config.max_position_embeddings
         self.qk_rope_head_dim = config.qk_rope_head_dim
@@ -246,10 +258,13 @@ def __init__(
             moe_config=moe_config,
             moe_overrides=moe_overrides,
         )
-        self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False)
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+        self.lm_head = initialize_linear_module(
+            self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype
+        )
         if self.backend.enable_hf_state_dict_adapter:
             self.state_dict_adapter = GlmMoeDsaStateDictAdapter(
-                self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
+                self.config, self.model.moe_config, self.backend, dtype=model_dtype
             )

     def get_input_embeddings(self):
diff --git a/nemo_automodel/components/models/gpt_oss/layers.py b/nemo_automodel/components/models/gpt_oss/layers.py
index f25c4ee36a..823f1113a0 100644
--- a/nemo_automodel/components/models/gpt_oss/layers.py
+++ b/nemo_automodel/components/models/gpt_oss/layers.py
@@ -34,6 +34,7 @@
     initialize_linear_module,
 )
 from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk
+from nemo_automodel.shared.utils import dtype_from_str as get_dtype


 class GptOssAttention(nn.Module):
@@ -47,17 +48,19 @@ def __init__(self, config: "GptOssConfig", backend: BackendConfig, use_sliding_a
         self.head_dim = config.head_dim
         self.hidden_size = config.hidden_size

+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         self.q_proj = initialize_linear_module(
-            backend.linear, self.hidden_size, self.num_attention_heads * self.head_dim, bias=True
+            backend.linear, self.hidden_size, self.num_attention_heads * self.head_dim, bias=True, dtype=dtype
         )
         self.k_proj = initialize_linear_module(
-            backend.linear, self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
+            backend.linear, self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True, dtype=dtype
         )
         self.v_proj = initialize_linear_module(
-            backend.linear, self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
+            backend.linear, self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True, dtype=dtype
         )
         self.o_proj = initialize_linear_module(
-            backend.linear, self.num_attention_heads * self.head_dim, self.hidden_size, bias=True
+            backend.linear, self.num_attention_heads * self.head_dim, self.hidden_size, bias=True, dtype=dtype
         )

         self.softmax_scale = self.head_dim**-0.5
diff --git a/nemo_automodel/components/models/gpt_oss/model.py b/nemo_automodel/components/models/gpt_oss/model.py
index 3a09fb4abc..f30cef75f0 100644
--- a/nemo_automodel/components/models/gpt_oss/model.py
+++ b/nemo_automodel/components/models/gpt_oss/model.py
@@ -46,9 +46,12 @@ def __init__(self, layer_idx: int, config: GptOssConfig, moe_config: MoEConfig,
             config, backend, use_sliding_attention=config.layer_types[layer_idx] == "sliding_attention"
         )
         self.mlp = MoE(moe_config, backend)
-        self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+        self.input_layernorm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
+        )
         self.post_attention_layernorm = initialize_rms_norm_module(
-            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
         )

     def forward(
@@ -102,6 +105,12 @@ def __init__(
         self.config = config
         if moe_config is not None and moe_overrides is not None:
             raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.")
+
+        # Resolve model dtype once; thread it explicitly to every sub-module
+        # so fp32 master weights work even when construction is not wrapped
+        # in local_torch_dtype().
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         # GPT-OSS is MoE everywhere; set shared experts to 0 to disable shared path in our MoE wrapper.
         moe_defaults = dict(
             dim=config.hidden_size,
@@ -123,18 +132,19 @@ def __init__(
             expert_activation="quick_geglu",
             activation_alpha=1.702,
             activation_limit=getattr(config, "swiglu_limit", 7.0),
+            dtype=model_dtype,
         )
         if moe_overrides:
             moe_defaults.update(moe_overrides)
         self.moe_config = moe_config or MoEConfig(**moe_defaults)

-        self.embed_tokens = nn.Embedding(
-            config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
-        )
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype)
         self.layers = torch.nn.ModuleDict()
         for layer_id in range(config.num_hidden_layers):
             self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend)
-        self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype
+        )

         # Rotary embedding cached at model-level (inv_freq + concentration via YaRN/NTK-by-parts)
         self.max_seq_len = config.max_position_embeddings
@@ -242,10 +252,13 @@ def __init__(
         self.backend = backend or BackendConfig(attn="flex")
         moe_overrides = kwargs.pop("moe_overrides", None)
         self.model = GptOssModel(config, backend=self.backend, moe_config=moe_config, moe_overrides=moe_overrides)
-        self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False)
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+        self.lm_head = initialize_linear_module(
+            self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype
+        )
         if self.backend.enable_hf_state_dict_adapter:
             self.state_dict_adapter = GPTOSSStateDictAdapter(
-                self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
+                self.config, self.model.moe_config, self.backend, dtype=model_dtype
             )

     def get_input_embeddings(self):
diff --git a/nemo_automodel/components/models/minimax_m2/layers.py b/nemo_automodel/components/models/minimax_m2/layers.py
index 07c9733429..ab48474732 100644
--- a/nemo_automodel/components/models/minimax_m2/layers.py
+++ b/nemo_automodel/components/models/minimax_m2/layers.py
@@ -28,6 +28,7 @@
     initialize_rms_norm_module,
 )
 from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk
+from nemo_automodel.shared.utils import dtype_from_str as get_dtype


 class MiniMaxM2Attention(nn.Module):
@@ -42,29 +43,35 @@ def __init__(self, config: Any, backend: BackendConfig):
         self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // self.num_heads
         self.use_qk_norm = getattr(config, "use_qk_norm", False)

+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         self.q_proj = initialize_linear_module(
             backend.linear,
             config.hidden_size,
             self.num_heads * self.head_dim,
             bias=False,
+            dtype=dtype,
         )
         self.k_proj = initialize_linear_module(
             backend.linear,
             config.hidden_size,
             self.num_kv_heads * self.head_dim,
             bias=False,
+            dtype=dtype,
         )
         self.v_proj = initialize_linear_module(
             backend.linear,
             config.hidden_size,
             self.num_kv_heads * self.head_dim,
             bias=False,
+            dtype=dtype,
         )
         self.o_proj = initialize_linear_module(
             backend.linear,
             self.num_heads * self.head_dim,
             config.hidden_size,
             bias=False,
+            dtype=dtype,
         )

         # HF MiniMax applies RMSNorm over flattened q/k projection dims before head reshape.
@@ -73,11 +80,13 @@ def __init__(self, config: Any, backend: BackendConfig):
                 backend.rms_norm,
                 self.num_heads * self.head_dim,
                 eps=config.rms_norm_eps,
+                dtype=dtype,
             )
             self.k_norm = initialize_rms_norm_module(
                 backend.rms_norm,
                 self.num_kv_heads * self.head_dim,
                 eps=config.rms_norm_eps,
+                dtype=dtype,
             )
         else:
             self.q_norm = None
diff --git a/nemo_automodel/components/models/minimax_m2/model.py b/nemo_automodel/components/models/minimax_m2/model.py
index 8a66709301..242045a3c2 100644
--- a/nemo_automodel/components/models/minimax_m2/model.py
+++ b/nemo_automodel/components/models/minimax_m2/model.py
@@ -42,9 +42,12 @@ def __init__(self, layer_idx: int, config: Any, moe_config: MoEConfig, backend:
         self.self_attn = MiniMaxM2Attention(config, backend)
         self.mlp = MoE(moe_config, backend)

-        self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+        self.input_layernorm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
+        )
         self.post_attention_layernorm = initialize_rms_norm_module(
-            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype
         )
         self.layer_idx = layer_idx

@@ -99,6 +102,11 @@ def __init__(
         score_func = getattr(config, "scoring_func", "sigmoid")
         score_func = "softmax" if str(score_func).lower() == "softmax" else "sigmoid"

+        # Resolve model dtype once; thread explicitly to every sub-module so
+        # fp32 master weights work even when construction is not wrapped in
+        # local_torch_dtype().
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         moe_defaults = dict(
             dim=config.hidden_size,
             inter_dim=config.intermediate_size,
@@ -119,23 +127,21 @@ def __init__(
             expert_activation="swiglu",
             softmax_before_topk=(score_func == "softmax"),
             force_e_score_correction_bias=True,
-            dtype=get_dtype(getattr(config, "torch_dtype", "bfloat16"), torch.bfloat16),
+            dtype=model_dtype,
         )
         if moe_overrides:
             moe_defaults.update(moe_overrides)
         self.moe_config = moe_config or MoEConfig(**moe_defaults)

-        self.embed_tokens = nn.Embedding(
-            config.vocab_size,
-            config.hidden_size,
-            dtype=get_dtype(getattr(config, "torch_dtype", "bfloat16"), torch.bfloat16),
-        )
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype)
         self.layers = torch.nn.ModuleDict()
         for layer_id in range(config.num_hidden_layers):
             self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend)
-        self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype
+        )

         self.max_seq_len = config.max_position_embeddings
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
@@ -261,13 +267,16 @@ def __init__(
             moe_config=moe_config,
             moe_overrides=moe_overrides,
         )
-        self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False)
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+        self.lm_head = initialize_linear_module(
+            self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype
+        )
         if self.backend.enable_hf_state_dict_adapter:
             self.state_dict_adapter = MiniMaxM2StateDictAdapter(
                 self.config,
                 self.model.moe_config,
                 self.backend,
-                dtype=get_dtype(getattr(config, "torch_dtype", "bfloat16"), torch.bfloat16),
+                dtype=model_dtype,
             )

     def get_input_embeddings(self):
diff --git a/nemo_automodel/components/models/mistral4/model.py b/nemo_automodel/components/models/mistral4/model.py
index b34e0b3d65..b0317bc1f4 100644
--- a/nemo_automodel/components/models/mistral4/model.py
+++ b/nemo_automodel/components/models/mistral4/model.py
@@ -174,6 +174,7 @@ def _build_moe_config(config, moe_overrides: dict | None = None) -> MoEConfig:
         route_scale=config.routed_scaling_factor,
         aux_loss_coeff=0,
         norm_topk_prob=config.norm_topk_prob,
+        dtype=get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16),
     )
     if moe_overrides:
         moe_defaults.update(moe_overrides)
@@ -195,13 +196,19 @@ def __init__(
         if moe_config is not None and moe_overrides is not None:
             raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.")
         self.moe_config = moe_config or _build_moe_config(config, moe_overrides=moe_overrides)
-        self.embed_tokens = nn.Embedding(
-            config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
-        )
+
+        # Resolve model dtype once; thread it explicitly to every sub-module
+        # so fp32 master weights work even when construction is not wrapped in
+        # local_torch_dtype().
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype)
         self.layers = torch.nn.ModuleDict()
         for layer_id in range(config.num_hidden_layers):
             self.layers[str(layer_id)] = Mistral4Block(layer_id, config, self.moe_config, backend)
-        self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype
+        )

         self.max_seq_len = config.max_position_embeddings
         rope_theta, rope_scaling, _ = get_rope_config(config)
@@ -337,10 +344,13 @@ def __init__(
             moe_config=moe_config,
             moe_overrides=moe_overrides,
         )
-        self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False)
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+        self.lm_head = initialize_linear_module(
+            self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype
+        )
         if self.backend.enable_hf_state_dict_adapter:
             self.state_dict_adapter = Mistral4StateDictAdapter(
-                self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
+                self.config, self.model.moe_config, self.backend, dtype=model_dtype
             )

     def get_input_embeddings(self):
@@ -468,7 +478,13 @@ def __init__(
         self.moe_config = self.model.moe_config
         # lm_head lives inside language_model (like KimiVLLanguageModelBackend)
         # so the parallelizer wraps it as part of _model, matching the Kimi pattern.
-        self.lm_head = initialize_linear_module(backend.linear, config.hidden_size, config.vocab_size, bias=False)
+        self.lm_head = initialize_linear_module(
+            backend.linear,
+            config.hidden_size,
+            config.vocab_size,
+            bias=False,
+            dtype=get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16),
+        )

     @property
     def embed_tokens(self):
@@ -675,6 +691,18 @@ def __init__(
         if num_hidden_layers is not None:
             config.text_config.num_hidden_layers = num_hidden_layers

+        # _init_model() only overrides the top-level hf_config.torch_dtype;
+        # for VL configs the nested text_config / vision_config keep their
+        # original dtype (typically bf16 from the checkpoint's
+        # config.json). Propagate the user-requested dtype to every nested
+        # sub-config that exposes a torch_dtype attribute, before building
+        # the vision tower / text backend.
+        top_dtype = getattr(config, "torch_dtype", None)
+        if top_dtype is not None:
+            for sub_cfg in vars(config).values():
+                if sub_cfg is not config and hasattr(sub_cfg, "torch_dtype"):
+                    sub_cfg.torch_dtype = top_dtype
+
         self.config = config
         self.backend = backend
         text_config = config.text_config
diff --git a/nemo_automodel/components/models/nemotron_v3/layers.py b/nemo_automodel/components/models/nemotron_v3/layers.py
index ce510babea..0ba76c80e4 100644
--- a/nemo_automodel/components/models/nemotron_v3/layers.py
+++ b/nemo_automodel/components/models/nemotron_v3/layers.py
@@ -29,6 +29,7 @@
     initialize_linear_module,
     initialize_rms_norm_module,
 )
+from nemo_automodel.shared.utils import dtype_from_str as get_dtype


 class NemotronV3Attention(nn.Module):
@@ -45,17 +46,34 @@ def __init__(self, config, backend: BackendConfig | None = None):
         self.attention_bias = getattr(config, "attention_bias", False)
         self.attention_dropout = getattr(config, "attention_dropout", 0.0)

+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
         self.q_proj = initialize_linear_module(
-            self.backend.linear, self.hidden_size, self.num_attention_heads * self.head_dim, self.attention_bias
+            self.backend.linear,
+            self.hidden_size,
+            self.num_attention_heads * self.head_dim,
+            self.attention_bias,
+            dtype=dtype,
         )
         self.k_proj = initialize_linear_module(
-            self.backend.linear, self.hidden_size, self.num_key_value_heads * self.head_dim, self.attention_bias
+            self.backend.linear,
+            self.hidden_size,
+            self.num_key_value_heads * self.head_dim,
+            self.attention_bias,
+            dtype=dtype,
         )
         self.v_proj = initialize_linear_module(
-            self.backend.linear, self.hidden_size, self.num_key_value_heads * self.head_dim, self.attention_bias
+            self.backend.linear,
+            self.hidden_size,
+            self.num_key_value_heads * self.head_dim,
+            self.attention_bias,
+            dtype=dtype,
         )
         self.o_proj = initialize_linear_module(
-            self.backend.linear, self.num_attention_heads * self.head_dim, self.hidden_size, self.attention_bias
+            self.backend.linear,
+            self.num_attention_heads * self.head_dim,
+            self.hidden_size,
+            self.attention_bias,
+            dtype=dtype,
         )

         softmax_scale = self.head_dim**-0.5
@@ -565,6 +583,7 @@ def __init__(self, config, layer_idx: int, moe_config=None, backend=None):
             backend.rms_norm,
             config.hidden_size,
             eps=config.layer_norm_epsilon,
+            dtype=get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16),
         )

         # Determine layer type from config
diff --git a/nemo_automodel/components/models/nemotron_v3/model.py b/nemo_automodel/components/models/nemotron_v3/model.py
index d8aa86cf25..24cd0c8529 100644
--- a/nemo_automodel/components/models/nemotron_v3/model.py
+++ b/nemo_automodel/components/models/nemotron_v3/model.py
@@ -107,6 +107,7 @@ def __init__(
             self.backend.rms_norm,
             config.hidden_size,
             eps=config.layer_norm_epsilon,
+            dtype=dtype,
         )

     def forward(
diff --git a/nemo_automodel/components/models/qwen3_5_moe/model.py b/nemo_automodel/components/models/qwen3_5_moe/model.py
index 67c31baf2e..abe71a6921 100644
--- a/nemo_automodel/components/models/qwen3_5_moe/model.py
+++ b/nemo_automodel/components/models/qwen3_5_moe/model.py
@@ -235,6 +235,11 @@ def __init__(
         self.padding_idx = getattr(config, "pad_token_id", None)
         self.vocab_size = config.vocab_size

+        # Resolve model dtype once; thread explicitly to every sub-module so
+        # fp32 master weights work even when construction is not wrapped in
+        # local_torch_dtype().
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         # --------------- MoE config ---------------
         # Qwen3.5-MoE has MoE on every layer, with a shared expert + sigmoid gate.
         # No ``decoder_sparse_step`` — defaults to 1 so every layer is MoE.
@@ -259,14 +264,14 @@ def __init__(
             softmax_before_topk=True,
             shared_expert_gate=True,
             shared_expert_inter_dim=config.shared_expert_intermediate_size,
+            dtype=model_dtype,
         )
         if moe_overrides:
             moe_defaults.update(moe_overrides)
         self.moe_config = moe_config or MoEConfig(**moe_defaults)

         # --------------- Layers ---------------
-        embed_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=embed_dtype)
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=model_dtype)

         # Use Qwen3_5MoeBlock — same as Qwen3Next Block but with native GatedDeltaNet
         self.layers = nn.ModuleDict(
@@ -429,6 +434,20 @@ def __init__(
         if not _QWEN3_5_MOE_HF_AVAILABLE:
             raise UnavailableError("transformers.models.qwen3_5_moe is not available.")
         backend = backend or BackendConfig()
+
+        # _init_model() only overrides the top-level hf_config.torch_dtype; for
+        # VL configs the nested text_config / vision_config keep their original
+        # dtype (typically bf16 from the checkpoint's config.json). Propagate
+        # the user-requested dtype to every nested sub-config that exposes a
+        # torch_dtype attribute, before constructing the HF parent (whose
+        # vision encoder / multimodal code may read sub-config torch_dtype) and
+        # our text backend.
+        top_dtype = getattr(config, "torch_dtype", None)
+        if top_dtype is not None:
+            for sub_cfg in vars(config).values():
+                if sub_cfg is not config and hasattr(sub_cfg, "torch_dtype"):
+                    sub_cfg.torch_dtype = top_dtype
+
         # Initialize HF parent (creates self.model, self.lm_head, vision encoder, etc.)
         super().__init__(config)

@@ -446,7 +465,11 @@ def __init__(
         # Replace lm_head with NeMo backend linear
         self.lm_head = initialize_linear_module(
-            self.backend.linear, text_config.hidden_size, text_config.vocab_size, bias=False
+            self.backend.linear,
+            text_config.hidden_size,
+            text_config.vocab_size,
+            bias=False,
+            dtype=get_dtype(getattr(text_config, "torch_dtype", None), torch.bfloat16),
         )

         # Expose moe_config for FSDP sync mixin
diff --git a/nemo_automodel/components/models/qwen3_moe/layers.py b/nemo_automodel/components/models/qwen3_moe/layers.py
index 249450d6e9..8c3fc131b3 100644
--- a/nemo_automodel/components/models/qwen3_moe/layers.py
+++ b/nemo_automodel/components/models/qwen3_moe/layers.py
@@ -29,6 +29,7 @@
     initialize_rms_norm_module,
 )
 from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk
+from nemo_automodel.shared.utils import dtype_from_str as get_dtype


 class Qwen3MoeAttention(nn.Module):
@@ -52,22 +53,27 @@ def __init__(self, config: Qwen3MoeConfig, backend: BackendConfig):
         attention_bias = getattr(config, "attention_bias", False)

+        # Thread dtype explicitly from config.torch_dtype so fp32 master
+        # weights work even when construction is not wrapped in
+        # local_torch_dtype().
+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         self.q_proj = initialize_linear_module(
-            backend.linear, config.hidden_size, self.num_heads * self.head_dim, attention_bias
+            backend.linear, config.hidden_size, self.num_heads * self.head_dim, attention_bias, dtype=dtype
         )
         self.k_proj = initialize_linear_module(
-            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias
+            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias, dtype=dtype
         )
         self.v_proj = initialize_linear_module(
-            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias
+            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias, dtype=dtype
         )
         self.o_proj = initialize_linear_module(
-            backend.linear, self.num_heads * self.head_dim, config.hidden_size, attention_bias
+            backend.linear, self.num_heads * self.head_dim, config.hidden_size, attention_bias, dtype=dtype
         )

         # Per-head RMSNorm
-        self.q_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps)
-        self.k_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps)
+        self.q_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps, dtype=dtype)
+        self.k_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps, dtype=dtype)

         # Attention implementation
         softmax_scale = self.head_dim**-0.5
+ model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + # Map HF Qwen3 MoE config -> our MoE wrapper # Qwen config fields from example config: # - hidden_size, intermediate_size, moe_intermediate_size, num_experts, num_experts_per_tok, norm_topk_prob @@ -133,18 +144,19 @@ def __init__( router_bias=False, expert_activation="swiglu", softmax_before_topk=True, + dtype=model_dtype, ) if moe_overrides: moe_defaults.update(moe_overrides) self.moe_config = moe_config or MoEConfig(**moe_defaults) - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16) - ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype) self.layers = torch.nn.ModuleDict() for layer_id in range(config.num_hidden_layers): self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend) - self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.norm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype + ) # Rotary embedding cache compatible with our rope_utils functions self.max_seq_len = config.max_position_embeddings @@ -250,10 +262,13 @@ def __init__( self.backend = backend or BackendConfig() moe_overrides = kwargs.pop("moe_overrides", None) self.model = Qwen3MoeModel(config, backend=self.backend, moe_config=moe_config, moe_overrides=moe_overrides) - self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False) + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.lm_head = initialize_linear_module( + self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype + ) if self.backend.enable_hf_state_dict_adapter: self.state_dict_adapter = Qwen3MoeStateDictAdapter( - self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16) + self.config, self.model.moe_config, self.backend, dtype=model_dtype ) def get_input_embeddings(self): diff --git a/nemo_automodel/components/models/qwen3_next/layers.py b/nemo_automodel/components/models/qwen3_next/layers.py index 8bf95206b8..5f14b8a893 100644 --- a/nemo_automodel/components/models/qwen3_next/layers.py +++ b/nemo_automodel/components/models/qwen3_next/layers.py @@ -28,6 +28,7 @@ initialize_linear_module, ) from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk +from nemo_automodel.shared.utils import dtype_from_str as get_dtype class Qwen3NextRMSNorm(nn.Module): @@ -64,18 +65,24 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int, backend: BackendConf self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads) self.num_key_value_groups = self.num_heads // self.num_kv_heads + # Thread dtype explicitly from config.torch_dtype so callers that do + # not wrap construction in local_torch_dtype() still get a dtype that + # matches the model's declared dtype (fp32 under fp32 master weights, + # bf16 otherwise). 
diff --git a/nemo_automodel/components/models/qwen3_next/layers.py b/nemo_automodel/components/models/qwen3_next/layers.py
index 8bf95206b8..5f14b8a893 100644
--- a/nemo_automodel/components/models/qwen3_next/layers.py
+++ b/nemo_automodel/components/models/qwen3_next/layers.py
@@ -28,6 +28,7 @@
     initialize_linear_module,
 )
 from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk
+from nemo_automodel.shared.utils import dtype_from_str as get_dtype


 class Qwen3NextRMSNorm(nn.Module):
@@ -64,18 +65,24 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int, backend: BackendConf
         self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
         self.num_key_value_groups = self.num_heads // self.num_kv_heads

+        # Thread dtype explicitly from config.torch_dtype so callers that do
+        # not wrap construction in local_torch_dtype() still get a dtype that
+        # matches the model's declared dtype (fp32 under fp32 master weights,
+        # bf16 otherwise).
+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         # Query projection outputs 2x size for gating
         self.q_proj = initialize_linear_module(
-            backend.linear, config.hidden_size, self.num_heads * self.head_dim * 2, False
+            backend.linear, config.hidden_size, self.num_heads * self.head_dim * 2, False, dtype=dtype
         )
         self.k_proj = initialize_linear_module(
-            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, False
+            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, False, dtype=dtype
         )
         self.v_proj = initialize_linear_module(
-            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, False
+            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, False, dtype=dtype
         )
         self.o_proj = initialize_linear_module(
-            backend.linear, self.num_heads * self.head_dim, config.hidden_size, False
+            backend.linear, self.num_heads * self.head_dim, config.hidden_size, False, dtype=dtype
         )

         self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
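Why thread the dtype at all: without an explicit dtype kwarg, parameter factories follow the ambient default dtype, which is exactly the local_torch_dtype() dependence this patch removes. A small illustration with plain nn.Linear:

```python
import torch
import torch.nn as nn

# Without an explicit dtype, parameters follow the ambient default
# (fp32 unless a set_default_dtype()-style context is active):
assert nn.Linear(8, 8, bias=False).weight.dtype is torch.float32

# Passing dtype explicitly makes construction independent of any wrapper:
assert nn.Linear(8, 8, bias=False, dtype=torch.bfloat16).weight.dtype is torch.bfloat16
```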
diff --git a/nemo_automodel/components/models/qwen3_next/model.py b/nemo_automodel/components/models/qwen3_next/model.py
index 7c7df49a24..4d3f3cfad6 100644
--- a/nemo_automodel/components/models/qwen3_next/model.py
+++ b/nemo_automodel/components/models/qwen3_next/model.py
@@ -54,7 +54,14 @@ def __init__(self, layer_idx: int, config: Qwen3NextConfig, moe_config: MoEConfi
         if is_moe_layer:
             self.mlp = MoE(moe_config, backend)
         else:
-            self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear)
+            # Thread dtype from config.torch_dtype so the dense MLP dtype stays
+            # aligned with the rest of the model (fp32 under fp32 master weights).
+            self.mlp = MLP(
+                config.hidden_size,
+                config.intermediate_size,
+                backend.linear,
+                dtype=get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16),
+            )

         self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -129,6 +136,11 @@ def __init__(
         if moe_config is not None and moe_overrides is not None:
             raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.")

+        # Resolve model dtype from config.torch_dtype once; thread it
+        # explicitly into every sub-module so fp32 master weights work even
+        # when construction is not wrapped in local_torch_dtype().
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         # Map HF Qwen3Next MoE config -> our MoE wrapper
         moe_defaults = dict(
             dim=config.hidden_size,
@@ -151,18 +163,19 @@
             softmax_before_topk=True,
             shared_expert_gate=True,
             shared_expert_inter_dim=config.shared_expert_intermediate_size,
+            dtype=model_dtype,
         )
         if moe_overrides:
             moe_defaults.update(moe_overrides)
         self.moe_config = moe_config or MoEConfig(**moe_defaults)

-        self.embed_tokens = nn.Embedding(
-            config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
-        )
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype)
         self.layers = torch.nn.ModuleDict()
         for layer_id in range(config.num_hidden_layers):
             self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend)
-        self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype
+        )

         # Rotary embedding cache compatible with our rope_utils functions
         self.max_seq_len = config.max_position_embeddings
@@ -269,10 +282,13 @@ def __init__(
         self.backend = backend or BackendConfig()
         moe_overrides = kwargs.pop("moe_overrides", None)
         self.model = Qwen3NextModel(config, backend=self.backend, moe_config=moe_config, moe_overrides=moe_overrides)
-        self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False)
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+        self.lm_head = initialize_linear_module(
+            self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype
+        )
         if self.backend.enable_hf_state_dict_adapter:
             self.state_dict_adapter = Qwen3NextStateDictAdapter(
-                self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
+                self.config, self.model.moe_config, self.backend, dtype=model_dtype
             )

     def get_input_embeddings(self):
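The "fp32 master weights" setup these comments refer to, sketched with FSDP2 primitives; assumes PyTorch 2.6+ (where fully_shard and MixedPrecisionPolicy are public) and an initialized device mesh, so the sharding calls are left as comments:

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Parameters are constructed and sharded in fp32; forward/backward
# all-gathers cast them to bf16 for compute, and grads reduce in fp32.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # dtype of all-gathered params in compute
    reduce_dtype=torch.float32,   # gradient reduction dtype
)

# model = Qwen3NextForCausalLM(config_with_torch_dtype_float32)  # fp32 params
# for block in model.model.layers.values():
#     fully_shard(block, mp_policy=mp_policy)
# fully_shard(model, mp_policy=mp_policy)
```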
diff --git a/nemo_automodel/components/models/qwen3_omni_moe/model.py b/nemo_automodel/components/models/qwen3_omni_moe/model.py
index 3766dbf588..5e9374ff0d 100644
--- a/nemo_automodel/components/models/qwen3_omni_moe/model.py
+++ b/nemo_automodel/components/models/qwen3_omni_moe/model.py
@@ -60,6 +60,12 @@ def __init__(
         # Map HF Qwen3OmniMoe config -> our MoE wrapper
         self.padding_idx = getattr(config, "pad_token_id", None)
         self.vocab_size = config.vocab_size
+
+        # Resolve model dtype once; thread it explicitly to every sub-module
+        # so fp32 master weights work even when construction is not wrapped in
+        # local_torch_dtype().
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+
         moe_defaults = dict(
             dim=config.hidden_size,
             inter_dim=config.intermediate_size,
@@ -79,16 +85,19 @@
             router_bias=False,
             expert_activation="swiglu",
             softmax_before_topk=True,
+            dtype=model_dtype,
         )
         if moe_overrides:
             moe_defaults.update(moe_overrides)
         self.moe_config = moe_config or MoEConfig(**moe_defaults)

-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=model_dtype)
         self.layers = nn.ModuleList(
             [Block(layer_id, config, self.moe_config, backend) for layer_id in range(config.num_hidden_layers)]
         )
-        self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype
+        )
         self.rotary_emb = Qwen3OmniMoeThinkerTextRotaryEmbedding(config)

     def forward(
@@ -222,6 +231,22 @@ def __init__(
         base_config = config.thinker_config if hasattr(config, "thinker_config") else config
         backend = backend or BackendConfig()

+        # _init_model() only overrides the top-level hf_config.torch_dtype; for
+        # Omni configs the real params live under thinker_config.text_config /
+        # thinker_config.vision_config, whose torch_dtype still holds the
+        # checkpoint's original value. Propagate the user-requested dtype to
+        # every nested sub-config that exposes a torch_dtype attribute (both
+        # the top-level and the thinker sub-config) so the HF parent, our
+        # text backend, and any HF vision / multimodal code agree.
+        top_dtype = getattr(config, "torch_dtype", None)
+        if top_dtype is not None:
+            for parent in (config, base_config):
+                for sub_cfg in vars(parent).values():
+                    if sub_cfg is not parent and hasattr(sub_cfg, "torch_dtype"):
+                        sub_cfg.torch_dtype = top_dtype
+            if base_config is not config and getattr(base_config, "torch_dtype", None) != top_dtype:
+                base_config.torch_dtype = top_dtype
+
         super().__init__(base_config)

         self.backend = backend
@@ -231,8 +256,9 @@
         self.model = Qwen3OmniMoeThinkerTextModel(
             text_config, backend=self.backend, moe_config=moe_config, moe_overrides=moe_overrides
         )
+        model_dtype = get_dtype(getattr(text_config, "torch_dtype", None), torch.bfloat16)
         self.lm_head = initialize_linear_module(
-            self.backend.linear, text_config.hidden_size, text_config.vocab_size, bias=False
+            self.backend.linear, text_config.hidden_size, text_config.vocab_size, bias=False, dtype=model_dtype
         )
         self.vocab_size = text_config.vocab_size

@@ -250,7 +276,7 @@
             text_config,
             self.model.moe_config,
             self.backend,
-            dtype=get_dtype(text_config.torch_dtype, torch.bfloat16),
+            dtype=model_dtype,
         )

     def get_input_embeddings(self):
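The propagation loop above in isolation; a runnable sketch using SimpleNamespace stand-ins for the real transformers PretrainedConfig objects:

```python
from types import SimpleNamespace

text_config = SimpleNamespace(torch_dtype="bfloat16")
vision_config = SimpleNamespace(torch_dtype="bfloat16")
config = SimpleNamespace(
    torch_dtype="float32",  # user-requested dtype set by the loader
    text_config=text_config,
    vision_config=vision_config,
)

# Copy the top-level dtype into every nested sub-config that exposes one.
top_dtype = getattr(config, "torch_dtype", None)
if top_dtype is not None:
    for sub_cfg in vars(config).values():
        if sub_cfg is not config and hasattr(sub_cfg, "torch_dtype"):
            sub_cfg.torch_dtype = top_dtype

assert text_config.torch_dtype == vision_config.torch_dtype == "float32"
```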
diff --git a/nemo_automodel/components/models/qwen3_vl_moe/model.py b/nemo_automodel/components/models/qwen3_vl_moe/model.py
index 3f50516b86..60d7132ade 100644
--- a/nemo_automodel/components/models/qwen3_vl_moe/model.py
+++ b/nemo_automodel/components/models/qwen3_vl_moe/model.py
@@ -292,6 +292,11 @@ def __init__(
         self.padding_idx = getattr(config, "pad_token_id", None)
         self.vocab_size = config.vocab_size
+
+        # Resolve model dtype once; thread it explicitly to every sub-module
+        # so fp32 master weights work even when construction is not wrapped in
+        # local_torch_dtype().
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)

         moe_defaults = dict(
             dim=config.hidden_size,
             inter_dim=config.intermediate_size,
@@ -311,20 +316,22 @@
             router_bias=False,
             expert_activation="swiglu",
             softmax_before_topk=True,
+            dtype=model_dtype,
         )
         if moe_overrides:
             moe_defaults.update(moe_overrides)
         self.moe_config = moe_config or MoEConfig(**moe_defaults)

-        embed_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=embed_dtype)
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=model_dtype)
         self.layers = nn.ModuleDict(
             {
                 str(layer_id): Qwen3VLMoeBlock(layer_id, config, self.moe_config, backend)
                 for layer_id in range(config.num_hidden_layers)
             }
         )
-        self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype
+        )
         self.rotary_emb = Fp32SafeQwen3VLMoeTextRotaryEmbedding(config=config)

     def forward(
@@ -462,6 +469,20 @@
         **kwargs,
     ):
         backend = backend or BackendConfig()
+
+        # _init_model() only overrides the top-level hf_config.torch_dtype; for
+        # VL configs the nested text_config / vision_config keep their original
+        # dtype (typically bf16 from the checkpoint's config.json). Propagate
+        # the user-requested dtype to every nested sub-config that exposes a
+        # torch_dtype attribute, before constructing the HF parent (whose
+        # vision encoder / multimodal code may read sub-config torch_dtype)
+        # and our text backend.
+        top_dtype = getattr(config, "torch_dtype", None)
+        if top_dtype is not None:
+            for sub_cfg in vars(config).values():
+                if sub_cfg is not config and hasattr(sub_cfg, "torch_dtype"):
+                    sub_cfg.torch_dtype = top_dtype
+
         super().__init__(config)

         self.backend = backend
@@ -472,8 +493,9 @@
         self.model.language_model = Qwen3VLMoeTextModelBackend(
             text_config, backend=self.backend, moe_config=moe_config, moe_overrides=moe_overrides
         )
+        model_dtype = get_dtype(getattr(text_config, "torch_dtype", None), torch.bfloat16)
         self.lm_head = initialize_linear_module(
-            self.backend.linear, text_config.hidden_size, text_config.vocab_size, bias=False
+            self.backend.linear, text_config.hidden_size, text_config.vocab_size, bias=False, dtype=model_dtype
         )

         self.model.moe_config = self.model.language_model.moe_config
@@ -486,7 +508,7 @@
             text_config,
             self.model.language_model.moe_config,
             self.backend,
-            dtype=get_dtype(text_config.torch_dtype, torch.bfloat16),
+            dtype=model_dtype,
         )

         vision_model = getattr(self.model, "visual")
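Ordering matters here: the sub-configs must be patched before super().__init__(config) runs, because the HF parent may build its vision tower from the nested torch_dtype. A hypothetical parent for illustration (names are stand-ins, not the HF classes):

```python
import torch
from types import SimpleNamespace

class Parent:  # hypothetical: reads the nested dtype during construction
    def __init__(self, config):
        self.vision_dtype = config.vision_config.torch_dtype

config = SimpleNamespace(
    torch_dtype=torch.float32,
    vision_config=SimpleNamespace(torch_dtype=torch.bfloat16),
)
config.vision_config.torch_dtype = config.torch_dtype  # propagate first...
assert Parent(config).vision_dtype is torch.float32    # ...so the parent agrees
```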
diff --git a/nemo_automodel/components/models/step3p5/layers.py b/nemo_automodel/components/models/step3p5/layers.py
index fa25d40c69..ce06d0320e 100644
--- a/nemo_automodel/components/models/step3p5/layers.py
+++ b/nemo_automodel/components/models/step3p5/layers.py
@@ -28,6 +28,7 @@
     initialize_linear_module,
 )
 from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk
+from nemo_automodel.shared.utils import dtype_from_str as get_dtype


 class Step3p5RMSNorm(nn.Module):
@@ -145,9 +146,16 @@ def __init__(
         self.intermediate_size = intermediate_size or config.intermediate_size
         self.swiglu_limit = swiglu_limit

-        self.gate_proj = initialize_linear_module(backend.linear, self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = initialize_linear_module(backend.linear, self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = initialize_linear_module(backend.linear, self.intermediate_size, self.hidden_size, bias=False)
+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+        self.gate_proj = initialize_linear_module(
+            backend.linear, self.hidden_size, self.intermediate_size, bias=False, dtype=dtype
+        )
+        self.up_proj = initialize_linear_module(
+            backend.linear, self.hidden_size, self.intermediate_size, bias=False, dtype=dtype
+        )
+        self.down_proj = initialize_linear_module(
+            backend.linear, self.intermediate_size, self.hidden_size, bias=False, dtype=dtype
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         up = self.up_proj(x)
@@ -198,17 +206,18 @@ def __init__(self, config: Any, layer_idx: int, backend: BackendConfig) -> None:
         # Projections
         attention_bias = getattr(config, "attention_bias", False)
+        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
         self.q_proj = initialize_linear_module(
-            backend.linear, config.hidden_size, self.num_heads * self.head_dim, attention_bias
+            backend.linear, config.hidden_size, self.num_heads * self.head_dim, attention_bias, dtype=dtype
         )
         self.k_proj = initialize_linear_module(
-            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias
+            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias, dtype=dtype
         )
         self.v_proj = initialize_linear_module(
-            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias
+            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias, dtype=dtype
         )
         self.o_proj = initialize_linear_module(
-            backend.linear, self.num_heads * self.head_dim, config.hidden_size, attention_bias
+            backend.linear, self.num_heads * self.head_dim, config.hidden_size, attention_bias, dtype=dtype
         )

         # Per-head Q/K normalization using Step3p5RMSNorm
@@ -218,7 +227,9 @@ def __init__(self, config: Any, layer_idx: int, backend: BackendConfig) -> None:
         # Optional head-wise attention gate
         self.use_head_wise_attn_gate = getattr(config, "use_head_wise_attn_gate", False)
         if self.use_head_wise_attn_gate:
-            self.g_proj = initialize_linear_module(backend.linear, config.hidden_size, self.num_heads, bias=False)
+            self.g_proj = initialize_linear_module(
+                backend.linear, config.hidden_size, self.num_heads, bias=False, dtype=dtype
+            )
         else:
             self.g_proj = None
diff --git a/nemo_automodel/components/models/step3p5/model.py b/nemo_automodel/components/models/step3p5/model.py
index e4bb2916d0..e9d4b511c0 100644
--- a/nemo_automodel/components/models/step3p5/model.py
+++ b/nemo_automodel/components/models/step3p5/model.py
@@ -395,14 +395,17 @@ def __init__(
         self.backend = backend or BackendConfig()
         moe_overrides = kwargs.pop("moe_overrides", None)
         self.model = Step3p5Model(config, backend=self.backend, moe_config=moe_config, moe_overrides=moe_overrides)
-        self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False)
+        model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+        self.lm_head = initialize_linear_module(
+            self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype
+        )
         if self.backend.enable_hf_state_dict_adapter:
             self.state_dict_adapter = Step3p5StateDictAdapter(
                 self.config,
                 self.model.moe_config,
                 self.backend,
-                dtype=get_dtype(getattr(config, "torch_dtype", "bfloat16"), torch.bfloat16),
+                dtype=model_dtype,
             )

     def get_input_embeddings(self):
diff --git a/nemo_automodel/components/moe/experts.py b/nemo_automodel/components/moe/experts.py
index 9d25a07c73..0986510fbf 100644
--- a/nemo_automodel/components/moe/experts.py
+++ b/nemo_automodel/components/moe/experts.py
@@ -293,21 +293,31 @@ def forward(
                 f"Number of experts must be divisible by ep_size (ep_size={ep_size})"
             )

+        # Cast expert weights to the activation dtype so that fp32-stored
+        # parameters (e.g. under fp32 master weights) still work with kernels
+        # (grouped_gemm / torch._grouped_mm) that require matching dtypes with
+        # the (typically bf16) activations. When the weights are already in the
+        # activation dtype these casts are no-ops.
+        compute_dtype = x.dtype
         gate_and_up_projs = (
             self.gate_and_up_projs.to_local()
             if isinstance(self.gate_and_up_projs, DTensor)
             else self.gate_and_up_projs
+        ).to(compute_dtype)
+        down_projs = (self.down_projs.to_local() if isinstance(self.down_projs, DTensor) else self.down_projs).to(
+            compute_dtype
         )
-        down_projs = self.down_projs.to_local() if isinstance(self.down_projs, DTensor) else self.down_projs
         gate_up_proj_bias = (
             (
                 self.gate_up_proj_bias.to_local()
                 if isinstance(self.gate_up_proj_bias, DTensor)
                 else self.gate_up_proj_bias
-            )
+            ).to(compute_dtype)
             if self.expert_bias
             else None
         )
         down_proj_bias = (
-            (self.down_proj_bias.to_local() if isinstance(self.down_proj_bias, DTensor) else self.down_proj_bias)
+            (self.down_proj_bias.to_local() if isinstance(self.down_proj_bias, DTensor) else self.down_proj_bias).to(
+                compute_dtype
+            )
             if self.expert_bias
             else None
         )
@@ -718,8 +728,12 @@ def forward(
         )
         permuted_probs = permuted_probs.unsqueeze(-1)

-        gate_and_up_projs = self.gate_and_up_projs.to_local()
-        down_projs = self.down_projs.to_local()
+        # Cast expert weights to the activation dtype so that fp32-stored
+        # parameters (e.g. under fp32 master weights) still work with kernels
+        # (grouped_gemm / torch._grouped_mm) that require matching dtypes with
+        # the (typically bf16) activations. When the weights are already in the
+        # activation dtype these casts are no-ops.
+        compute_dtype = permuted_local_hidden_states.dtype
+        gate_and_up_projs = self.gate_and_up_projs.to_local().to(compute_dtype)
+        down_projs = self.down_projs.to_local().to(compute_dtype)

-        # Match activation dtype for grouped_mm; see GroupedExperts.forward.
-        gate_and_up_projs = gate_and_up_projs.to(permuted_local_hidden_states.dtype)
@@ -737,11 +751,11 @@
             # Apply bias manually after each grouped GEMM via _apply_bias.
             offs = tokens_per_expert_gpu.cumsum(dim=0).to(torch.int32)
             output1 = torch._grouped_mm(permuted_local_hidden_states, gate_and_up_projs, offs=offs)
-            gate_up_proj_bias = self.gate_up_proj_bias.to_local()
+            gate_up_proj_bias = self.gate_up_proj_bias.to_local().to(compute_dtype)
             output1 = _apply_bias(output1, gate_up_proj_bias, tokens_per_expert)
             output1 = self.expert_activation(output1, permuted_probs)
             output2 = torch._grouped_mm(output1, down_projs, offs=offs)
-            down_bias = self.down_proj_bias.to_local()
+            down_bias = self.down_proj_bias.to_local().to(compute_dtype)
             output2 = _apply_bias(output2, down_bias, tokens_per_expert, permuted_probs)
         else:
             output2 = _torch_mm_experts_fwd(
@@ -762,14 +776,14 @@
             )

             if self.expert_bias:
-                gate_up_proj_bias = self.gate_up_proj_bias.to_local()
+                gate_up_proj_bias = self.gate_up_proj_bias.to_local().to(compute_dtype)
                 output1 = _apply_bias(output1, gate_up_proj_bias, tokens_per_expert)

             output1 = self.expert_activation(output1, permuted_probs)
             output2 = ops.gmm(output1, down_projs, tokens_per_expert, trans_b=False)

             if self.expert_bias:
-                down_bias = self.down_proj_bias.to_local()
+                down_bias = self.down_proj_bias.to_local().to(compute_dtype)
                 output2 = _apply_bias(output2, down_bias, tokens_per_expert, permuted_probs)
         else:
             output1 = torch.matmul(x[0] * 0, gate_and_up_projs[0])
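What the casts guard against, reduced to a plain matmul; grouped_gemm and torch._grouped_mm enforce the same same-dtype contract between activations and weights:

```python
import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)   # activations (typically bf16)
w = torch.randn(8, 16, dtype=torch.float32)   # fp32-stored master weight

try:
    torch.matmul(x, w)                        # mismatched dtypes raise
except RuntimeError as err:
    print("dtype mismatch:", err)

y = torch.matmul(x, w.to(x.dtype))            # cast; a no-op if already bf16
assert y.dtype is torch.bfloat16
```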
diff --git a/nemo_automodel/components/moe/parallelizer.py b/nemo_automodel/components/moe/parallelizer.py
index f22fbd79a8..e264bc54e2 100644
--- a/nemo_automodel/components/moe/parallelizer.py
+++ b/nemo_automodel/components/moe/parallelizer.py
@@ -209,11 +209,16 @@ def apply_fsdp(
         if isinstance(moe_module, MoE) and ep_shard_enabled:
             # Apply FSDP on dim=1 for grouped experts since we may have more
             # shards than experts (dim=0).
+            # Forward the same mp_policy used elsewhere so that when params are
+            # kept in fp32 (e.g. for fp32 master weights under FSDP2) the
+            # all-gathered expert weights are still cast to param_dtype for
+            # forward compute (required by GMM / TE kernels that expect bf16).
             fully_shard(
                 moe_module.experts,
                 mesh=ep_shard_mesh,
                 shard_placement_fn=lambda _: Shard(1),
                 reshard_after_forward=reshard_after_forward,
+                mp_policy=mp_policy,
             )
         # If FSDP is disabled for grouped experts because the parameters are already
         # fully sharded by PP and EP, then we need to explicitly remove the parameters
diff --git a/nemo_automodel/shared/utils.py b/nemo_automodel/shared/utils.py
index 021538168f..2f02296440 100644
--- a/nemo_automodel/shared/utils.py
+++ b/nemo_automodel/shared/utils.py
@@ -29,6 +29,16 @@ def dtype_from_str(val, default=torch.bfloat16):

     if isinstance(val, torch.dtype):
         return val
+
+    # From here on, val must be one of the recognized dtype strings below;
+    # anything else (e.g. a Mock attribute auto-created in tests that build
+    # configs with unittest.mock.Mock) falls back to the default, so a call
+    # site like get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
+    # stays safe instead of crashing deep inside string concatenation.
+    if not isinstance(val, str):
+        assert isinstance(default, torch.dtype), default
+        return default
+
     lut = {
         "torch.float": torch.float,
         "torch.float32": torch.float,