From 7471563ea08dacd79e64b93e2606766f9d80f07a Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Fri, 17 Apr 2026 02:38:31 -0700 Subject: [PATCH 1/6] fix(moe): plumb mp_policy end-to-end and cast expert weights in forward The MoE code path silently dropped the FSDP2 MixedPrecisionPolicy configured in FSDP2Config: * _shard_ep_fsdp() never forwarded model_wrapper.mp_policy into parallelize_fn, so apply_fsdp() fell back to a hardcoded output_dtype=torch.bfloat16 default that disagreed with NeMo-RL's intended output_dtype=torch.float32. * fully_shard(moe_module.experts, ...) in the ep_shard branch was called without an mp_policy, so ep-sharded expert weights were never cast to the forward activation dtype during all-gather. * For full-EP configs (ep_size == world_size, DP=1), the experts are excluded from FSDP entirely (see ignored_params), so no FSDP wrapper ever carries an mp_policy for them. With fp32-stored expert params (e.g. under fp32 master weights) and bf16 activations propagated by the block's mp_policy, grouped_gemm.ops.gmm / torch._grouped_mm then crash with "Expected b.scalar_type() == torch::kBFloat16". Changes: * _transformers/infrastructure.py: forward model_wrapper.mp_policy into parallelize_fn (overriding the moe_parallelizer mp_policy default when the FSDP2Config policy is present). * moe/parallelizer.py: pass mp_policy to the experts' own fully_shard call so ep_shard-sharded experts honour the forward cast. * moe/experts.py: in GroupedExpertsDeepEP.forward and GroupedExperts.forward, cast expert weights (and biases) to the activation dtype before the grouped GEMM. This is a no-op when weights already match the activation dtype, and rescues the full-EP case where no FSDP wrapper can carry mp_policy. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Zhaopeng Qiu --- .../_transformers/infrastructure.py | 8 +++-- nemo_automodel/components/moe/experts.py | 34 ++++++++++++++----- nemo_automodel/components/moe/parallelizer.py | 5 +++ 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/nemo_automodel/_transformers/infrastructure.py b/nemo_automodel/_transformers/infrastructure.py index 8376552fd2..df437d3c46 100644 --- a/nemo_automodel/_transformers/infrastructure.py +++ b/nemo_automodel/_transformers/infrastructure.py @@ -299,9 +299,11 @@ def instantiate_infrastructure( if ep_size > 1: from nemo_automodel.components.moe.parallelizer import parallelize_model - parallelize_fn = partial( - parallelize_model, activation_checkpointing=activation_checkpointing, **moe_config.to_dict() - ) + moe_kwargs = moe_config.to_dict() + # Forward mp_policy from distributed config if not explicitly set in MoE config + if moe_kwargs.get("mp_policy") is None and model_wrapper is not None: + moe_kwargs["mp_policy"] = getattr(model_wrapper, "mp_policy", None) + parallelize_fn = partial(parallelize_model, activation_checkpointing=activation_checkpointing, **moe_kwargs) elif autopipeline is not None and model_wrapper is not None: parallelize_fn = partial(parallelize_for_pp, model_wrapper=model_wrapper) diff --git a/nemo_automodel/components/moe/experts.py b/nemo_automodel/components/moe/experts.py index 86e10e6b35..f274569047 100644 --- a/nemo_automodel/components/moe/experts.py +++ b/nemo_automodel/components/moe/experts.py @@ -293,21 +293,31 @@ def forward( f"Number of experts must be divisible by ep_size (ep_size={ep_size})" ) + # Cast expert weights to the activation dtype so that fp32-stored + # parameters (e.g. 
under fp32 master weights) still work with kernels + # (grouped_gemm / torch._grouped_mm) that require matching dtypes with + # the (typically bf16) activations. When the weights are already in the + # activation dtype these casts are no-ops. + compute_dtype = x.dtype gate_and_up_projs = ( self.gate_and_up_projs.to_local() if isinstance(self.gate_and_up_projs, DTensor) else self.gate_and_up_projs + ).to(compute_dtype) + down_projs = (self.down_projs.to_local() if isinstance(self.down_projs, DTensor) else self.down_projs).to( + compute_dtype ) - down_projs = self.down_projs.to_local() if isinstance(self.down_projs, DTensor) else self.down_projs gate_up_proj_bias = ( ( self.gate_up_proj_bias.to_local() if isinstance(self.gate_up_proj_bias, DTensor) else self.gate_up_proj_bias - ) + ).to(compute_dtype) if self.expert_bias else None ) down_proj_bias = ( - (self.down_proj_bias.to_local() if isinstance(self.down_proj_bias, DTensor) else self.down_proj_bias) + (self.down_proj_bias.to_local() if isinstance(self.down_proj_bias, DTensor) else self.down_proj_bias).to( + compute_dtype + ) if self.expert_bias else None ) @@ -685,8 +695,14 @@ def forward( ) permuted_probs = permuted_probs.unsqueeze(-1) - gate_and_up_projs = self.gate_and_up_projs.to_local() - down_projs = self.down_projs.to_local() + # Cast expert weights to the activation dtype so that fp32-stored + # parameters (e.g. under fp32 master weights) still work with kernels + # (grouped_gemm / torch._grouped_mm) that require matching dtypes with + # the (typically bf16) activations. When the weights are already in the + # activation dtype these casts are no-ops. + compute_dtype = permuted_local_hidden_states.dtype + gate_and_up_projs = self.gate_and_up_projs.to_local().to(compute_dtype) + down_projs = self.down_projs.to_local().to(compute_dtype) if torch.count_nonzero(tokens_per_expert) > 0: if self.use_torch_mm: @@ -700,11 +716,11 @@ def forward( # Apply bias manually after each grouped GEMM via _apply_bias. 
offs = tokens_per_expert_gpu.cumsum(dim=0).to(torch.int32) output1 = torch._grouped_mm(permuted_local_hidden_states, gate_and_up_projs, offs=offs) - gate_up_proj_bias = self.gate_up_proj_bias.to_local() + gate_up_proj_bias = self.gate_up_proj_bias.to_local().to(compute_dtype) output1 = _apply_bias(output1, gate_up_proj_bias, tokens_per_expert) output1 = self.expert_activation(output1, permuted_probs) output2 = torch._grouped_mm(output1, down_projs, offs=offs) - down_bias = self.down_proj_bias.to_local() + down_bias = self.down_proj_bias.to_local().to(compute_dtype) output2 = _apply_bias(output2, down_bias, tokens_per_expert, permuted_probs) else: output2 = _torch_mm_experts_fwd( @@ -725,14 +741,14 @@ def forward( ) if self.expert_bias: - gate_up_proj_bias = self.gate_up_proj_bias.to_local() + gate_up_proj_bias = self.gate_up_proj_bias.to_local().to(compute_dtype) output1 = _apply_bias(output1, gate_up_proj_bias, tokens_per_expert) output1 = self.expert_activation(output1, permuted_probs) output2 = ops.gmm(output1, down_projs, tokens_per_expert, trans_b=False) if self.expert_bias: - down_bias = self.down_proj_bias.to_local() + down_bias = self.down_proj_bias.to_local().to(compute_dtype) output2 = _apply_bias(output2, down_bias, tokens_per_expert, permuted_probs) else: output1 = torch.matmul(x[0] * 0, gate_and_up_projs[0]) diff --git a/nemo_automodel/components/moe/parallelizer.py b/nemo_automodel/components/moe/parallelizer.py index f22fbd79a8..e264bc54e2 100644 --- a/nemo_automodel/components/moe/parallelizer.py +++ b/nemo_automodel/components/moe/parallelizer.py @@ -209,11 +209,16 @@ def apply_fsdp( if isinstance(moe_module, MoE) and ep_shard_enabled: # Apply FSDP on dim=1 for grouped experts since we may have more # shards than experts (dim=0). + # Forward the same mp_policy used elsewhere so that when params are + # kept in fp32 (e.g. for fp32 master weights under FSDP2) the + # all-gathered expert weights are still cast to param_dtype for + # forward compute (required by GMM / TE kernels that expect bf16). fully_shard( moe_module.experts, mesh=ep_shard_mesh, shard_placement_fn=lambda _: Shard(1), reshard_after_forward=reshard_after_forward, + mp_policy=mp_policy, ) # If FSDP is disabled for grouped experts because the parameters are already # fully sharded by PP and EP, then we need to explicitly remove the parameters From 48144bd51c28d2cc1d68bd9dabc32dad5b6bfbe2 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Fri, 17 Apr 2026 02:38:47 -0700 Subject: [PATCH 2/6] fix(checkpoint): preserve model param dtype across initialize_weights() initialize_model_weights() calls model.initialize_weights() with no args, but several custom models (qwen3_next, qwen3_5_moe, ...) have a signature of initialize_weights(self, buffer_device=None, dtype=torch.bfloat16) that ends with cast_model_to_dtype(self, dtype). That silently re-casts every floating-point parameter back to bfloat16 on the first rank that hits this code path, undoing fp32 initialization (e.g. for fp32 master weights under FSDP2). Infer the target dtype from the existing (floating-point) parameters and pass it through when the model accepts a dtype kwarg, falling back to the no-kwarg call for older signatures. This preserves the dtype chosen at construction time (bf16 by default, fp32 when the user requested fp32 master weights) without requiring every model's initialize_weights to change. 
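To make the preserved behaviour concrete, here is a minimal, self-contained
sketch. ToyModel and its initialize_weights() are hypothetical stand-ins for
the custom-model pattern described above, the whole-model cast stands in for
cast_model_to_dtype(), and the dtype-inference logic has the same shape as
the change in this patch:

    import torch
    import torch.nn as nn

    # Hypothetical stand-in: initialize_weights() defaults dtype=torch.bfloat16
    # and ends with a whole-model cast, clobbering a model constructed in fp32.
    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8, dtype=torch.float32)

        def initialize_weights(self, buffer_device=None, dtype=torch.bfloat16):
            nn.init.normal_(self.proj.weight)
            self.to(dtype)  # approximates cast_model_to_dtype(self, dtype)

    model = ToyModel()

    # Infer the construction dtype from the first floating-point parameter and
    # prefer passing it through; fall back to the no-kwarg call for signatures
    # that do not accept dtype.
    param_dtype = next(
        (p.dtype for p in model.parameters() if p.is_floating_point()), None
    )
    try:
        if param_dtype is not None:
            model.initialize_weights(dtype=param_dtype)
        else:
            model.initialize_weights()
    except TypeError:
        model.initialize_weights()

    assert model.proj.weight.dtype == torch.float32  # fp32 preserved

With the old no-kwarg call, the same model would leave initialize_weights()
in bf16, which is exactly the silent re-cast described above.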
The checkpoint-load path is unaffected: DCP copies tensors into the model's existing parameters, so dtypes follow the model rather than the checkpoint. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Zhaopeng Qiu --- .../components/checkpoint/checkpointing.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/nemo_automodel/components/checkpoint/checkpointing.py b/nemo_automodel/components/checkpoint/checkpointing.py index 4a8b7edbdd..d819738bb2 100644 --- a/nemo_automodel/components/checkpoint/checkpointing.py +++ b/nemo_automodel/components/checkpoint/checkpointing.py @@ -558,7 +558,23 @@ def initialize_model_weights( module._is_hf_initialized = False if hasattr(model, "initialize_weights"): - model.initialize_weights() + # Infer the target dtype from existing (floating-point) + # parameters so that a model constructed in fp32 (e.g. for fp32 + # master weights under FSDP2) is not silently cast back to + # bf16 inside model.initialize_weights() -> cast_model_to_dtype(). + param_dtype = None + for p in model.parameters(): + if p.is_floating_point(): + param_dtype = p.dtype + break + try: + if param_dtype is not None: + model.initialize_weights(dtype=param_dtype) + else: + model.initialize_weights() + except TypeError: + # Model's initialize_weights() does not accept a dtype kwarg. + model.initialize_weights() else: logging.warning( "Warning: Model does not have initialize_weights method." From 8428dcc56517207afe83fe0fb6cc3fce552e61b1 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Fri, 17 Apr 2026 02:39:15 -0700 Subject: [PATCH 3/6] fix(models): thread config.torch_dtype explicitly in custom MoE models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit initialize_linear_module / initialize_rms_norm_module default dtype=torch.bfloat16, and MoEConfig.dtype defaults to torch.bfloat16. Custom MoE models (qwen3_next, qwen3_5_moe) silently accepted those defaults and ignored config.torch_dtype / torch.get_default_dtype(), so attention projections, MoE experts, MLP, lm_head and RMSNorm were all bf16 even when the user requested torch_dtype=float32 (e.g. for fp32 master weights under FSDP2). Thread the model dtype explicitly from config.torch_dtype at every helper call site: * qwen3_next/layers.py Qwen3NextAttention: pass dtype to q/k/v/o_proj. * qwen3_next/model.py Block: pass dtype to dense MLP; Qwen3NextModel: pass dtype to MoEConfig, embed_tokens, RMSNorm; Qwen3NextForCausalLM: pass dtype to lm_head and state_dict_adapter. * qwen3_5_moe/model.py Qwen3_5MoeTextModelBackend: pass dtype to MoEConfig and embed_tokens. Qwen3_5MoeForConditionalGeneration: pass dtype to lm_head. Also fix a nested-sub-config dtype leak in the VL class. _init_model() overrides only the top-level hf_config.torch_dtype; for VL configs like Qwen3_5MoeConfig the nested text_config / vision_config keep their original (typically bf16) torch_dtype from the checkpoint's config.json. Without the fix below, the text backend would then read text_config.torch_dtype=bf16 while HF-native submodules (e.g. CPAwareGatedDeltaNet, constructed via super().__init__(config) inside local_torch_dtype(fp32)) used get_default_dtype()=fp32, producing a mixed-dtype state that crashed with a generic CuBLAS Error at the second grouped GEMM during validation. 
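For concreteness, a minimal stand-alone sketch of that propagation. The
SimpleNamespace objects are hypothetical stand-ins for a VL config and its
nested sub-configs; the loop mirrors the one added in this patch:

    from types import SimpleNamespace
    import torch

    # Stand-in for a VL config after _init_model(): the top-level torch_dtype
    # has been overridden to fp32, but the nested sub-configs still carry the
    # bf16 dtype loaded from the checkpoint's config.json.
    config = SimpleNamespace(
        torch_dtype=torch.float32,
        text_config=SimpleNamespace(torch_dtype=torch.bfloat16),
        vision_config=SimpleNamespace(torch_dtype=torch.bfloat16),
    )

    # Push the top-level dtype into every nested sub-config that exposes a
    # torch_dtype attribute, before the HF parent or the text backend read it.
    top_dtype = getattr(config, "torch_dtype", None)
    if top_dtype is not None:
        for sub_cfg in vars(config).values():
            if sub_cfg is not config and hasattr(sub_cfg, "torch_dtype"):
                sub_cfg.torch_dtype = top_dtype

    assert config.text_config.torch_dtype == torch.float32
    assert config.vision_config.torch_dtype == torch.float32

After this, the text backend and any HF vision / multimodal code that reads
sub-config torch_dtype see the same dtype as the top-level config.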
Propagate the user-requested dtype to every nested sub-config exposing a torch_dtype attribute, before calling super().__init__(config), so the HF parent, the text backend, and any HF vision / multimodal code that reads sub-config torch_dtype all agree. Rationale for threading instead of flipping helper defaults: keeping initialize_linear_module / MoEConfig bf16 defaults preserves the existing API contract for third-party callers that construct models directly (no local_torch_dtype() wrapper) — PyTorch's global default is fp32, which would otherwise silently make projection modules fp32 while embeddings stayed bf16. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Zhaopeng Qiu --- .../components/models/qwen3_5_moe/model.py | 29 ++++++++++++++++-- .../components/models/qwen3_next/layers.py | 15 +++++++--- .../components/models/qwen3_next/model.py | 30 ++++++++++++++----- 3 files changed, 60 insertions(+), 14 deletions(-) diff --git a/nemo_automodel/components/models/qwen3_5_moe/model.py b/nemo_automodel/components/models/qwen3_5_moe/model.py index 67c31baf2e..abe71a6921 100644 --- a/nemo_automodel/components/models/qwen3_5_moe/model.py +++ b/nemo_automodel/components/models/qwen3_5_moe/model.py @@ -235,6 +235,11 @@ def __init__( self.padding_idx = getattr(config, "pad_token_id", None) self.vocab_size = config.vocab_size + # Resolve model dtype once; thread explicitly to every sub-module so + # fp32 master weights work even when construction is not wrapped in + # local_torch_dtype(). + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + # --------------- MoE config --------------- # Qwen3.5-MoE has MoE on every layer, with a shared expert + sigmoid gate. # No ``decoder_sparse_step`` — defaults to 1 so every layer is MoE. @@ -259,14 +264,14 @@ def __init__( softmax_before_topk=True, shared_expert_gate=True, shared_expert_inter_dim=config.shared_expert_intermediate_size, + dtype=model_dtype, ) if moe_overrides: moe_defaults.update(moe_overrides) self.moe_config = moe_config or MoEConfig(**moe_defaults) # --------------- Layers --------------- - embed_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=embed_dtype) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=model_dtype) # Use Qwen3_5MoeBlock — same as Qwen3Next Block but with native GatedDeltaNet self.layers = nn.ModuleDict( @@ -429,6 +434,20 @@ def __init__( if not _QWEN3_5_MOE_HF_AVAILABLE: raise UnavailableError("transformers.models.qwen3_5_moe is not available.") backend = backend or BackendConfig() + + # _init_model() only overrides the top-level hf_config.torch_dtype; for + # VL configs the nested text_config / vision_config keep their original + # dtype (typically bf16 from the checkpoint's config.json). Propagate + # the user-requested dtype to every nested sub-config that exposes a + # torch_dtype attribute, before constructing the HF parent (whose + # vision encoder / multimodal code may read sub-config torch_dtype) and + # our text backend. + top_dtype = getattr(config, "torch_dtype", None) + if top_dtype is not None: + for sub_cfg in vars(config).values(): + if sub_cfg is not config and hasattr(sub_cfg, "torch_dtype"): + sub_cfg.torch_dtype = top_dtype + # Initialize HF parent (creates self.model, self.lm_head, vision encoder, etc.) 
super().__init__(config) @@ -446,7 +465,11 @@ def __init__( # Replace lm_head with NeMo backend linear self.lm_head = initialize_linear_module( - self.backend.linear, text_config.hidden_size, text_config.vocab_size, bias=False + self.backend.linear, + text_config.hidden_size, + text_config.vocab_size, + bias=False, + dtype=get_dtype(getattr(text_config, "torch_dtype", None), torch.bfloat16), ) # Expose moe_config for FSDP sync mixin diff --git a/nemo_automodel/components/models/qwen3_next/layers.py b/nemo_automodel/components/models/qwen3_next/layers.py index 8bf95206b8..5f14b8a893 100644 --- a/nemo_automodel/components/models/qwen3_next/layers.py +++ b/nemo_automodel/components/models/qwen3_next/layers.py @@ -28,6 +28,7 @@ initialize_linear_module, ) from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk +from nemo_automodel.shared.utils import dtype_from_str as get_dtype class Qwen3NextRMSNorm(nn.Module): @@ -64,18 +65,24 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int, backend: BackendConf self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads) self.num_key_value_groups = self.num_heads // self.num_kv_heads + # Thread dtype explicitly from config.torch_dtype so callers that do + # not wrap construction in local_torch_dtype() still get a dtype that + # matches the model's declared dtype (fp32 under fp32 master weights, + # bf16 otherwise). + dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + # Query projection outputs 2x size for gating self.q_proj = initialize_linear_module( - backend.linear, config.hidden_size, self.num_heads * self.head_dim * 2, False + backend.linear, config.hidden_size, self.num_heads * self.head_dim * 2, False, dtype=dtype ) self.k_proj = initialize_linear_module( - backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, False + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, False, dtype=dtype ) self.v_proj = initialize_linear_module( - backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, False + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, False, dtype=dtype ) self.o_proj = initialize_linear_module( - backend.linear, self.num_heads * self.head_dim, config.hidden_size, False + backend.linear, self.num_heads * self.head_dim, config.hidden_size, False, dtype=dtype ) self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) diff --git a/nemo_automodel/components/models/qwen3_next/model.py b/nemo_automodel/components/models/qwen3_next/model.py index 6501f024b5..db868e4cdf 100644 --- a/nemo_automodel/components/models/qwen3_next/model.py +++ b/nemo_automodel/components/models/qwen3_next/model.py @@ -54,7 +54,14 @@ def __init__(self, layer_idx: int, config: Qwen3NextConfig, moe_config: MoEConfi if is_moe_layer: self.mlp = MoE(moe_config, backend) else: - self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear) + # Thread dtype from config.torch_dtype so the dense MLP dtype stays + # aligned with the rest of the model (fp32 under fp32 master weights). 
+ self.mlp = MLP( + config.hidden_size, + config.intermediate_size, + backend.linear, + dtype=get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16), + ) self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -133,6 +140,11 @@ def __init__( if moe_config is not None and moe_overrides is not None: raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.") + # Resolve model dtype from config.torch_dtype once; thread it + # explicitly into every sub-module so fp32 master weights work even + # when construction is not wrapped in local_torch_dtype(). + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + # Map HF Qwen3Next MoE config -> our MoE wrapper moe_defaults = dict( dim=config.hidden_size, @@ -155,18 +167,19 @@ def __init__( softmax_before_topk=True, shared_expert_gate=True, shared_expert_inter_dim=config.shared_expert_intermediate_size, + dtype=model_dtype, ) if moe_overrides: moe_defaults.update(moe_overrides) self.moe_config = moe_config or MoEConfig(**moe_defaults) - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16) - ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype) self.layers = torch.nn.ModuleDict() for layer_id in range(config.num_hidden_layers): self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend) - self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.norm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype + ) # Rotary embedding cache compatible with our rope_utils functions self.max_seq_len = config.max_position_embeddings @@ -273,10 +286,13 @@ def __init__( self.backend = backend or BackendConfig() moe_overrides = kwargs.pop("moe_overrides", None) self.model = Qwen3NextModel(config, backend=self.backend, moe_config=moe_config, moe_overrides=moe_overrides) - self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False) + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.lm_head = initialize_linear_module( + self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype + ) if self.backend.enable_hf_state_dict_adapter: self.state_dict_adapter = Qwen3NextStateDictAdapter( - self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16) + self.config, self.model.moe_config, self.backend, dtype=model_dtype ) def get_input_embeddings(self): From 067fb1bc72fae5dfafef469c9c72bf25e4c7db82 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Fri, 17 Apr 2026 03:09:26 -0700 Subject: [PATCH 4/6] fix(models): thread dtype in remaining text-only custom MoE models Extend the explicit torch_dtype threading pattern established for qwen3_next / qwen3_5_moe to every other text-only custom MoE (and the shared MLA attention) so fp32 master weights work consistently across the full set of custom implementations, regardless of whether the caller wrapped construction in local_torch_dtype(). Pattern applied per model: * Attention layers.py: resolve dtype from config.torch_dtype once and thread it into initialize_linear_module / initialize_rms_norm_module for every q/k/v/o_proj and q_norm/k_norm call. 
* model.py Block: pass dtype to the dense-path MLP(...) and the per-block RMSNorms. * model.py Model: resolve model_dtype once and thread it into MoEConfig(dict dtype=...), nn.Embedding(dtype=...), and the final initialize_rms_norm_module(dtype=...). * model.py ForCausalLM: thread model_dtype into initialize_linear_module for lm_head and into the state_dict_adapter dtype arg. Models updated: * qwen3_moe * gpt_oss * minimax_m2 * glm4_moe, glm4_moe_lite, glm_moe_dsa * deepseek_v3, deepseek_v32 (both share MLA; v3.2 also has Indexer) * step3p5 * nemotron_v3 (final model norm + attention projections) All changes are purely call-site threading; helper defaults (initialize_linear_module / initialize_rms_norm_module / MoEConfig.dtype) remain bfloat16 so third-party direct construction keeps its previous behaviour. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Zhaopeng Qiu --- .../components/models/deepseek_v3/layers.py | 17 ++++++++-- .../components/models/deepseek_v3/model.py | 34 ++++++++++++++----- .../components/models/deepseek_v32/layers.py | 24 ++++++++++--- .../components/models/deepseek_v32/model.py | 34 ++++++++++++++----- .../components/models/glm4_moe/layers.py | 19 +++++++---- .../components/models/glm4_moe/model.py | 33 +++++++++++++----- .../components/models/glm4_moe_lite/model.py | 33 +++++++++++++----- .../components/models/glm_moe_dsa/model.py | 33 +++++++++++++----- .../components/models/gpt_oss/layers.py | 11 +++--- .../components/models/gpt_oss/model.py | 29 +++++++++++----- .../components/models/minimax_m2/layers.py | 9 +++++ .../components/models/minimax_m2/model.py | 31 +++++++++++------ .../components/models/nemotron_v3/layers.py | 27 ++++++++++++--- .../components/models/nemotron_v3/model.py | 1 + .../components/models/qwen3_moe/layers.py | 18 ++++++---- .../components/models/qwen3_moe/model.py | 33 +++++++++++++----- .../components/models/step3p5/layers.py | 27 ++++++++++----- .../components/models/step3p5/model.py | 7 ++-- 18 files changed, 310 insertions(+), 110 deletions(-) diff --git a/nemo_automodel/components/models/deepseek_v3/layers.py b/nemo_automodel/components/models/deepseek_v3/layers.py index 3aef514046..1ab5a6ab5b 100644 --- a/nemo_automodel/components/models/deepseek_v3/layers.py +++ b/nemo_automodel/components/models/deepseek_v3/layers.py @@ -32,6 +32,7 @@ apply_rotary_emb_qk, yarn_get_mscale, ) +from nemo_automodel.shared.utils import dtype_from_str as get_dtype class MLA(nn.Module): @@ -55,6 +56,7 @@ def __init__(self, config: DeepseekV3Config, backend: BackendConfig): rms_norm_impl = backend.rms_norm hidden_size = config.hidden_size + dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) if self.q_lora_rank is None: self.q_proj = initialize_linear_module( @@ -62,19 +64,25 @@ def __init__(self, config: DeepseekV3Config, backend: BackendConfig): in_features=hidden_size, out_features=self.n_heads * self.qk_head_dim, bias=False, + dtype=dtype, ) else: self.q_a_proj = initialize_linear_module( - linear_impl=linear_impl, in_features=hidden_size, out_features=self.q_lora_rank, bias=False + linear_impl=linear_impl, + in_features=hidden_size, + out_features=self.q_lora_rank, + bias=False, + dtype=dtype, ) self.q_a_layernorm = initialize_rms_norm_module( - rms_norm_impl=rms_norm_impl, dim=self.q_lora_rank, eps=config.rms_norm_eps + rms_norm_impl=rms_norm_impl, dim=self.q_lora_rank, eps=config.rms_norm_eps, dtype=dtype ) self.q_b_proj = initialize_linear_module( linear_impl=linear_impl, in_features=self.q_lora_rank, 
out_features=self.n_heads * self.qk_head_dim, bias=False, + dtype=dtype, ) self.kv_a_proj_with_mqa = initialize_linear_module( @@ -82,21 +90,24 @@ def __init__(self, config: DeepseekV3Config, backend: BackendConfig): in_features=hidden_size, out_features=self.kv_lora_rank + self.qk_rope_head_dim, bias=False, + dtype=dtype, ) self.kv_a_layernorm = initialize_rms_norm_module( - rms_norm_impl=rms_norm_impl, dim=self.kv_lora_rank, eps=config.rms_norm_eps + rms_norm_impl=rms_norm_impl, dim=self.kv_lora_rank, eps=config.rms_norm_eps, dtype=dtype ) self.kv_b_proj = initialize_linear_module( linear_impl=linear_impl, in_features=self.kv_lora_rank, out_features=self.n_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, + dtype=dtype, ) self.o_proj = initialize_linear_module( linear_impl=linear_impl, in_features=self.n_heads * self.v_head_dim, out_features=hidden_size, bias=False, + dtype=dtype, ) self.softmax_scale = self.qk_head_dim**-0.5 diff --git a/nemo_automodel/components/models/deepseek_v3/model.py b/nemo_automodel/components/models/deepseek_v3/model.py index ae3c9a0000..e5cf308fbc 100644 --- a/nemo_automodel/components/models/deepseek_v3/model.py +++ b/nemo_automodel/components/models/deepseek_v3/model.py @@ -46,13 +46,20 @@ def __init__( ): super().__init__() self.self_attn = MLA(config, backend) + + # Thread dtype from config.torch_dtype so the block's own params stay + # aligned with the rest of the model (fp32 under fp32 master weights). + dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + if layer_idx < config.first_k_dense_replace: - self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear) + self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype) else: self.mlp = MoE(moe_config, backend) - self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype + ) self.post_attention_layernorm = initialize_rms_norm_module( - backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype ) self.layer_idx = layer_idx @@ -127,6 +134,11 @@ def __init__( self.config = config if moe_config is not None and moe_overrides is not None: raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.") + # Resolve model dtype once; thread it explicitly to every sub-module + # so fp32 master weights work even when construction is not wrapped in + # local_torch_dtype(). 
+ model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + moe_defaults = dict( dim=config.hidden_size, inter_dim=config.intermediate_size, @@ -142,17 +154,18 @@ def __init__( route_scale=config.routed_scaling_factor, aux_loss_coeff=0, norm_topk_prob=config.norm_topk_prob, + dtype=model_dtype, ) if moe_overrides: moe_defaults.update(moe_overrides) self.moe_config = moe_config or MoEConfig(**moe_defaults) - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16) - ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype) self.layers = torch.nn.ModuleDict() for layer_id in range(config.num_hidden_layers): self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend) - self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.norm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype + ) self.max_seq_len = config.max_position_embeddings rope_theta, rope_scaling, _ = get_rope_config(config) @@ -282,10 +295,13 @@ def __init__( moe_config=moe_config, moe_overrides=moe_overrides, ) - self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False) + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.lm_head = initialize_linear_module( + self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype + ) if self.backend.enable_hf_state_dict_adapter: self.state_dict_adapter = DeepSeekV3StateDictAdapter( - self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16) + self.config, self.model.moe_config, self.backend, dtype=model_dtype ) def get_input_embeddings(self): diff --git a/nemo_automodel/components/models/deepseek_v32/layers.py b/nemo_automodel/components/models/deepseek_v32/layers.py index a3e7938e79..8f66b8fac1 100644 --- a/nemo_automodel/components/models/deepseek_v32/layers.py +++ b/nemo_automodel/components/models/deepseek_v32/layers.py @@ -72,6 +72,7 @@ def hadamard_transform(x: torch.Tensor, scale: float) -> torch.Tensor: yarn_get_mscale, ) from nemo_automodel.components.models.deepseek_v32.config import DeepseekV32Config +from nemo_automodel.shared.utils import dtype_from_str as get_dtype def _rotate_activation(x: torch.Tensor) -> torch.Tensor: @@ -120,6 +121,7 @@ def __init__(self, config: DeepseekV32Config, backend: BackendConfig): self.backend = backend linear_impl = backend.linear + dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) # Project Q from q_lora residual -> num_heads * head_dim self.wq_b = initialize_linear_module( @@ -127,6 +129,7 @@ def __init__(self, config: DeepseekV32Config, backend: BackendConfig): in_features=self.q_lora_rank, out_features=self.num_heads * self.head_dim, bias=False, + dtype=dtype, ) # Project K from hidden states -> single head_dim (shared across heads) @@ -135,10 +138,11 @@ def __init__(self, config: DeepseekV32Config, backend: BackendConfig): in_features=self.hidden_size, out_features=self.head_dim, bias=False, + dtype=dtype, ) # LayerNorm for K (official uses LayerNorm, not RMSNorm) - self.k_norm = nn.LayerNorm(self.head_dim) + self.k_norm = nn.LayerNorm(self.head_dim, dtype=dtype) # Per-head weight projection from hidden states self.weights_proj = initialize_linear_module( @@ -146,6 +150,7 @@ def __init__(self, config: 
DeepseekV32Config, backend: BackendConfig): in_features=self.hidden_size, out_features=self.num_heads, bias=False, + dtype=dtype, ) def forward( @@ -299,17 +304,23 @@ def __init__(self, config: DeepseekV32Config, backend: BackendConfig): rms_norm_impl = backend.rms_norm hidden_size = config.hidden_size + dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) # V3.2 always uses q_lora (q_lora_rank is not None) self.q_a_proj = initialize_linear_module( - linear_impl=linear_impl, in_features=hidden_size, out_features=self.q_lora_rank, bias=False + linear_impl=linear_impl, + in_features=hidden_size, + out_features=self.q_lora_rank, + bias=False, + dtype=dtype, ) - self.q_a_layernorm = initialize_rms_norm_module(rms_norm_impl=rms_norm_impl, dim=self.q_lora_rank) + self.q_a_layernorm = initialize_rms_norm_module(rms_norm_impl=rms_norm_impl, dim=self.q_lora_rank, dtype=dtype) self.q_b_proj = initialize_linear_module( linear_impl=linear_impl, in_features=self.q_lora_rank, out_features=self.n_heads * self.qk_head_dim, bias=False, + dtype=dtype, ) self.kv_a_proj_with_mqa = initialize_linear_module( @@ -317,19 +328,24 @@ def __init__(self, config: DeepseekV32Config, backend: BackendConfig): in_features=hidden_size, out_features=self.kv_lora_rank + self.qk_rope_head_dim, bias=False, + dtype=dtype, + ) + self.kv_a_layernorm = initialize_rms_norm_module( + rms_norm_impl=rms_norm_impl, dim=self.kv_lora_rank, dtype=dtype ) - self.kv_a_layernorm = initialize_rms_norm_module(rms_norm_impl=rms_norm_impl, dim=self.kv_lora_rank) self.kv_b_proj = initialize_linear_module( linear_impl=linear_impl, in_features=self.kv_lora_rank, out_features=self.n_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, + dtype=dtype, ) self.o_proj = initialize_linear_module( linear_impl=linear_impl, in_features=self.n_heads * self.v_head_dim, out_features=hidden_size, bias=False, + dtype=dtype, ) self.softmax_scale = self.qk_head_dim**-0.5 diff --git a/nemo_automodel/components/models/deepseek_v32/model.py b/nemo_automodel/components/models/deepseek_v32/model.py index cc4ddecc10..ffb5fe0520 100644 --- a/nemo_automodel/components/models/deepseek_v32/model.py +++ b/nemo_automodel/components/models/deepseek_v32/model.py @@ -59,14 +59,20 @@ def __init__( from nemo_automodel.components.models.common import initialize_rms_norm_module from nemo_automodel.components.moe.layers import MLP, MoE + # Thread dtype from config.torch_dtype so the block's own params stay + # aligned with the rest of the model (fp32 under fp32 master weights). 
+ dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + if layer_idx < config.first_k_dense_replace: - self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear) + self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype) else: self.mlp = MoE(moe_config, backend) - self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype + ) self.post_attention_layernorm = initialize_rms_norm_module( - backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype ) self.layer_idx = layer_idx @@ -92,6 +98,12 @@ def __init__( self.config = config if moe_config is not None and moe_overrides is not None: raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.") + + # Resolve model dtype once; thread it explicitly to every sub-module + # so fp32 master weights work even when construction is not wrapped in + # local_torch_dtype(). + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + moe_defaults = dict( dim=config.hidden_size, inter_dim=config.intermediate_size, @@ -107,19 +119,20 @@ def __init__( route_scale=config.routed_scaling_factor, aux_loss_coeff=0, norm_topk_prob=config.norm_topk_prob, + dtype=model_dtype, ) if moe_overrides: moe_defaults.update(moe_overrides) self.moe_config = moe_config or MoEConfig(**moe_defaults) - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16) - ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype) self.layers = torch.nn.ModuleDict() for layer_id in range(config.num_hidden_layers): # Use V3.2 Block instead of V3 Block self.layers[str(layer_id)] = DeepseekV32Block(layer_id, config, self.moe_config, backend) - self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.norm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype + ) self.max_seq_len = config.max_position_embeddings rope_theta, rope_scaling, _ = get_rope_config(config) @@ -183,11 +196,14 @@ def __init__( moe_config=moe_config, moe_overrides=moe_overrides, ) - self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False) + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.lm_head = initialize_linear_module( + self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype + ) if self.backend.enable_hf_state_dict_adapter: # Use V3.2 adapter instead of V3 adapter self.state_dict_adapter = DeepSeekV32StateDictAdapter( - self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16) + self.config, self.model.moe_config, self.backend, dtype=model_dtype ) def get_input_embeddings(self): diff --git a/nemo_automodel/components/models/glm4_moe/layers.py b/nemo_automodel/components/models/glm4_moe/layers.py index b4ddd5766f..ea6202679d 100644 --- a/nemo_automodel/components/models/glm4_moe/layers.py +++ b/nemo_automodel/components/models/glm4_moe/layers.py @@ -29,6 +29,7 @@ initialize_rms_norm_module, ) from nemo_automodel.components.models.gpt_oss.rope_utils import 
apply_rotary_emb_qk +from nemo_automodel.shared.utils import dtype_from_str as get_dtype class Glm4MoeAttention(nn.Module): @@ -58,23 +59,29 @@ def __init__(self, config: Glm4MoeConfig, backend: BackendConfig): self.use_qk_norm = config.use_qk_norm self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) + dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.q_proj = initialize_linear_module( - backend.linear, config.hidden_size, self.num_heads * self.head_dim, config.attention_bias + backend.linear, config.hidden_size, self.num_heads * self.head_dim, config.attention_bias, dtype=dtype ) self.k_proj = initialize_linear_module( - backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, config.attention_bias + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, config.attention_bias, dtype=dtype ) self.v_proj = initialize_linear_module( - backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, config.attention_bias + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, config.attention_bias, dtype=dtype ) self.o_proj = initialize_linear_module( - backend.linear, self.num_heads * self.head_dim, config.hidden_size, False + backend.linear, self.num_heads * self.head_dim, config.hidden_size, False, dtype=dtype ) # Optional per-head RMSNorm for Q and K if self.use_qk_norm: - self.q_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps) - self.k_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps) + self.q_norm = initialize_rms_norm_module( + backend.rms_norm, self.head_dim, eps=config.rms_norm_eps, dtype=dtype + ) + self.k_norm = initialize_rms_norm_module( + backend.rms_norm, self.head_dim, eps=config.rms_norm_eps, dtype=dtype + ) # Attention implementation softmax_scale = self.head_dim**-0.5 diff --git a/nemo_automodel/components/models/glm4_moe/model.py b/nemo_automodel/components/models/glm4_moe/model.py index b4a9b6668b..e1b0f0ea43 100644 --- a/nemo_automodel/components/models/glm4_moe/model.py +++ b/nemo_automodel/components/models/glm4_moe/model.py @@ -41,16 +41,22 @@ def __init__(self, layer_idx: int, config: Glm4MoeConfig, moe_config: MoEConfig, super().__init__() self.self_attn = Glm4MoeAttention(config, backend) + # Thread dtype from config.torch_dtype so the block's own params stay + # aligned with the rest of the model (fp32 under fp32 master weights). 
+ dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + # GLM4-MoE uses dense layers for first_k_dense_replace layers, then MoE is_moe_layer = layer_idx >= config.first_k_dense_replace if is_moe_layer: self.mlp = MoE(moe_config, backend) else: - self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear) + self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype) - self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype + ) self.post_attention_layernorm = initialize_rms_norm_module( - backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype ) self.layer_idx = layer_idx @@ -107,6 +113,11 @@ def __init__( if moe_config is not None and moe_overrides is not None: raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.") + # Resolve model dtype once; thread it explicitly to every sub-module + # so fp32 master weights work even when construction is not wrapped in + # local_torch_dtype(). + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + # Map HF GLM4 MoE config -> our MoE wrapper # GLM4 MoE config fields: # - hidden_size, intermediate_size, moe_intermediate_size @@ -131,18 +142,19 @@ def __init__( router_bias=False, expert_activation="swiglu", softmax_before_topk=False, # GLM4 uses sigmoid, not softmax + dtype=model_dtype, ) if moe_overrides: moe_defaults.update(moe_overrides) self.moe_config = moe_config or MoEConfig(**moe_defaults) - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16) - ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype) self.layers = torch.nn.ModuleDict() for layer_id in range(config.num_hidden_layers): self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend) - self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.norm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype + ) # Rotary embedding cache compatible with our rope_utils functions # GLM4 MoE uses partial rotary embeddings @@ -257,10 +269,13 @@ def __init__( moe_config=moe_config, moe_overrides=moe_overrides, ) - self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False) + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.lm_head = initialize_linear_module( + self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype + ) if self.backend.enable_hf_state_dict_adapter: self.state_dict_adapter = Glm4MoeStateDictAdapter( - self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16) + self.config, self.model.moe_config, self.backend, dtype=model_dtype ) def get_input_embeddings(self): diff --git a/nemo_automodel/components/models/glm4_moe_lite/model.py b/nemo_automodel/components/models/glm4_moe_lite/model.py index b60dafa13c..23d71719d8 100644 --- a/nemo_automodel/components/models/glm4_moe_lite/model.py +++ b/nemo_automodel/components/models/glm4_moe_lite/model.py @@ -42,6 +42,10 @@ def __init__(self, layer_idx: int, 
config: Any, moe_config: MoEConfig, backend: super().__init__() self.self_attn = MLA(config, backend) + # Thread dtype from config.torch_dtype so the block's own params stay + # aligned with the rest of the model (fp32 under fp32 master weights). + dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + # GLM4-MoE-Lite uses mlp_layer_types to determine dense vs MoE layers mlp_layer_types = getattr(config, "mlp_layer_types", None) if mlp_layer_types is not None: @@ -54,11 +58,13 @@ def __init__(self, layer_idx: int, config: Any, moe_config: MoEConfig, backend: if is_moe_layer: self.mlp = MoE(moe_config, backend) else: - self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear) + self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype) - self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype + ) self.post_attention_layernorm = initialize_rms_norm_module( - backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype ) self.layer_idx = layer_idx @@ -115,6 +121,11 @@ def __init__( if moe_config is not None and moe_overrides is not None: raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.") + # Resolve model dtype once; thread it explicitly to every sub-module + # so fp32 master weights work even when construction is not wrapped in + # local_torch_dtype(). + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + # Map config -> MoE wrapper (same as GLM4 MoE) moe_defaults = dict( dim=config.hidden_size, @@ -135,18 +146,19 @@ def __init__( router_bias=False, expert_activation="swiglu", softmax_before_topk=False, # GLM4 uses sigmoid, not softmax + dtype=model_dtype, ) if moe_overrides: moe_defaults.update(moe_overrides) self.moe_config = moe_config or MoEConfig(**moe_defaults) - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16) - ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype) self.layers = torch.nn.ModuleDict() for layer_id in range(config.num_hidden_layers): self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend) - self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.norm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype + ) # Rotary embedding for MLA # MLA uses qk_rope_head_dim for the rope dimension @@ -258,10 +270,13 @@ def __init__( moe_config=moe_config, moe_overrides=moe_overrides, ) - self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False) + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.lm_head = initialize_linear_module( + self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype + ) if self.backend.enable_hf_state_dict_adapter: self.state_dict_adapter = Glm4MoeStateDictAdapter( - self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16) + self.config, self.model.moe_config, self.backend, dtype=model_dtype ) def get_input_embeddings(self): diff --git 
a/nemo_automodel/components/models/glm_moe_dsa/model.py b/nemo_automodel/components/models/glm_moe_dsa/model.py index cfa8fe38dc..67bb40d188 100644 --- a/nemo_automodel/components/models/glm_moe_dsa/model.py +++ b/nemo_automodel/components/models/glm_moe_dsa/model.py @@ -37,6 +37,10 @@ def __init__(self, layer_idx: int, config: GlmMoeDsaConfig, moe_config: MoEConfi super().__init__() self.self_attn = DeepseekV32MLA(config, backend) + # Thread dtype from config.torch_dtype so the block's own params stay + # aligned with the rest of the model (fp32 under fp32 master weights). + dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + mlp_layer_types = getattr(config, "mlp_layer_types", None) if mlp_layer_types is not None: is_moe_layer = mlp_layer_types[layer_idx] == "sparse" @@ -47,11 +51,13 @@ def __init__(self, layer_idx: int, config: GlmMoeDsaConfig, moe_config: MoEConfi if is_moe_layer: self.mlp = MoE(moe_config, backend) else: - self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear) + self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype) - self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype + ) self.post_attention_layernorm = initialize_rms_norm_module( - backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype ) self.layer_idx = layer_idx @@ -108,6 +114,11 @@ def __init__( if moe_config is not None and moe_overrides is not None: raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.") + # Resolve model dtype once; thread it explicitly to every sub-module + # so fp32 master weights work even when construction is not wrapped in + # local_torch_dtype(). 
+ model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + moe_defaults = dict( dim=config.hidden_size, inter_dim=config.intermediate_size, @@ -127,18 +138,19 @@ def __init__( router_bias=False, expert_activation="swiglu", softmax_before_topk=False, + dtype=model_dtype, ) if moe_overrides: moe_defaults.update(moe_overrides) self.moe_config = moe_config or MoEConfig(**moe_defaults) - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16) - ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype) self.layers = torch.nn.ModuleDict() for layer_id in range(config.num_hidden_layers): self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend) - self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.norm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype + ) self.max_seq_len = config.max_position_embeddings self.qk_rope_head_dim = config.qk_rope_head_dim @@ -246,10 +258,13 @@ def __init__( moe_config=moe_config, moe_overrides=moe_overrides, ) - self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False) + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.lm_head = initialize_linear_module( + self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype + ) if self.backend.enable_hf_state_dict_adapter: self.state_dict_adapter = GlmMoeDsaStateDictAdapter( - self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16) + self.config, self.model.moe_config, self.backend, dtype=model_dtype ) def get_input_embeddings(self): diff --git a/nemo_automodel/components/models/gpt_oss/layers.py b/nemo_automodel/components/models/gpt_oss/layers.py index f25c4ee36a..823f1113a0 100644 --- a/nemo_automodel/components/models/gpt_oss/layers.py +++ b/nemo_automodel/components/models/gpt_oss/layers.py @@ -34,6 +34,7 @@ initialize_linear_module, ) from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk +from nemo_automodel.shared.utils import dtype_from_str as get_dtype class GptOssAttention(nn.Module): @@ -47,17 +48,19 @@ def __init__(self, config: "GptOssConfig", backend: BackendConfig, use_sliding_a self.head_dim = config.head_dim self.hidden_size = config.hidden_size + dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.q_proj = initialize_linear_module( - backend.linear, self.hidden_size, self.num_attention_heads * self.head_dim, bias=True + backend.linear, self.hidden_size, self.num_attention_heads * self.head_dim, bias=True, dtype=dtype ) self.k_proj = initialize_linear_module( - backend.linear, self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True + backend.linear, self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True, dtype=dtype ) self.v_proj = initialize_linear_module( - backend.linear, self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True + backend.linear, self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True, dtype=dtype ) self.o_proj = initialize_linear_module( - backend.linear, self.num_attention_heads * self.head_dim, self.hidden_size, bias=True + backend.linear, self.num_attention_heads * self.head_dim, self.hidden_size, bias=True, dtype=dtype ) self.softmax_scale = 
self.head_dim**-0.5 diff --git a/nemo_automodel/components/models/gpt_oss/model.py b/nemo_automodel/components/models/gpt_oss/model.py index 3a09fb4abc..f30cef75f0 100644 --- a/nemo_automodel/components/models/gpt_oss/model.py +++ b/nemo_automodel/components/models/gpt_oss/model.py @@ -46,9 +46,12 @@ def __init__(self, layer_idx: int, config: GptOssConfig, moe_config: MoEConfig, config, backend, use_sliding_attention=config.layer_types[layer_idx] == "sliding_attention" ) self.mlp = MoE(moe_config, backend) - self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.input_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype + ) self.post_attention_layernorm = initialize_rms_norm_module( - backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype ) def forward( @@ -102,6 +105,12 @@ def __init__( self.config = config if moe_config is not None and moe_overrides is not None: raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.") + + # Resolve model dtype once; thread it explicitly to every sub-module + # so fp32 master weights work even when construction is not wrapped + # in local_torch_dtype(). + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + # GPT-OSS is MoE everywhere; set shared experts to 0 to disable shared path in our MoE wrapper. moe_defaults = dict( dim=config.hidden_size, @@ -123,18 +132,19 @@ def __init__( expert_activation="quick_geglu", activation_alpha=1.702, activation_limit=getattr(config, "swiglu_limit", 7.0), + dtype=model_dtype, ) if moe_overrides: moe_defaults.update(moe_overrides) self.moe_config = moe_config or MoEConfig(**moe_defaults) - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16) - ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype) self.layers = torch.nn.ModuleDict() for layer_id in range(config.num_hidden_layers): self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend) - self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.norm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype + ) # Rotary embedding cached at model-level (inv_freq + concentration via YaRN/NTK-by-parts) self.max_seq_len = config.max_position_embeddings @@ -242,10 +252,13 @@ def __init__( self.backend = backend or BackendConfig(attn="flex") moe_overrides = kwargs.pop("moe_overrides", None) self.model = GptOssModel(config, backend=self.backend, moe_config=moe_config, moe_overrides=moe_overrides) - self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False) + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.lm_head = initialize_linear_module( + self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype + ) if self.backend.enable_hf_state_dict_adapter: self.state_dict_adapter = GPTOSSStateDictAdapter( - self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16) + self.config, self.model.moe_config, self.backend, dtype=model_dtype ) def 
get_input_embeddings(self): diff --git a/nemo_automodel/components/models/minimax_m2/layers.py b/nemo_automodel/components/models/minimax_m2/layers.py index 07c9733429..ab48474732 100644 --- a/nemo_automodel/components/models/minimax_m2/layers.py +++ b/nemo_automodel/components/models/minimax_m2/layers.py @@ -28,6 +28,7 @@ initialize_rms_norm_module, ) from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk +from nemo_automodel.shared.utils import dtype_from_str as get_dtype class MiniMaxM2Attention(nn.Module): @@ -42,29 +43,35 @@ def __init__(self, config: Any, backend: BackendConfig): self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // self.num_heads self.use_qk_norm = getattr(config, "use_qk_norm", False) + dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.q_proj = initialize_linear_module( backend.linear, config.hidden_size, self.num_heads * self.head_dim, bias=False, + dtype=dtype, ) self.k_proj = initialize_linear_module( backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, + dtype=dtype, ) self.v_proj = initialize_linear_module( backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, + dtype=dtype, ) self.o_proj = initialize_linear_module( backend.linear, self.num_heads * self.head_dim, config.hidden_size, bias=False, + dtype=dtype, ) # HF MiniMax applies RMSNorm over flattened q/k projection dims before head reshape. @@ -73,11 +80,13 @@ def __init__(self, config: Any, backend: BackendConfig): backend.rms_norm, self.num_heads * self.head_dim, eps=config.rms_norm_eps, + dtype=dtype, ) self.k_norm = initialize_rms_norm_module( backend.rms_norm, self.num_kv_heads * self.head_dim, eps=config.rms_norm_eps, + dtype=dtype, ) else: self.q_norm = None diff --git a/nemo_automodel/components/models/minimax_m2/model.py b/nemo_automodel/components/models/minimax_m2/model.py index 8a66709301..242045a3c2 100644 --- a/nemo_automodel/components/models/minimax_m2/model.py +++ b/nemo_automodel/components/models/minimax_m2/model.py @@ -42,9 +42,12 @@ def __init__(self, layer_idx: int, config: Any, moe_config: MoEConfig, backend: self.self_attn = MiniMaxM2Attention(config, backend) self.mlp = MoE(moe_config, backend) - self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.input_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype + ) self.post_attention_layernorm = initialize_rms_norm_module( - backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype ) self.layer_idx = layer_idx @@ -99,6 +102,11 @@ def __init__( score_func = getattr(config, "scoring_func", "sigmoid") score_func = "softmax" if str(score_func).lower() == "softmax" else "sigmoid" + # Resolve model dtype once; thread explicitly to every sub-module so + # fp32 master weights work even when construction is not wrapped in + # local_torch_dtype(). 
+ model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + moe_defaults = dict( dim=config.hidden_size, inter_dim=config.intermediate_size, @@ -119,23 +127,21 @@ def __init__( expert_activation="swiglu", softmax_before_topk=(score_func == "softmax"), force_e_score_correction_bias=True, - dtype=get_dtype(getattr(config, "torch_dtype", "bfloat16"), torch.bfloat16), + dtype=model_dtype, ) if moe_overrides: moe_defaults.update(moe_overrides) self.moe_config = moe_config or MoEConfig(**moe_defaults) - self.embed_tokens = nn.Embedding( - config.vocab_size, - config.hidden_size, - dtype=get_dtype(getattr(config, "torch_dtype", "bfloat16"), torch.bfloat16), - ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype) self.layers = torch.nn.ModuleDict() for layer_id in range(config.num_hidden_layers): self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend) - self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.norm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype + ) self.max_seq_len = config.max_position_embeddings self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) @@ -261,13 +267,16 @@ def __init__( moe_config=moe_config, moe_overrides=moe_overrides, ) - self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False) + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.lm_head = initialize_linear_module( + self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype + ) if self.backend.enable_hf_state_dict_adapter: self.state_dict_adapter = MiniMaxM2StateDictAdapter( self.config, self.model.moe_config, self.backend, - dtype=get_dtype(getattr(config, "torch_dtype", "bfloat16"), torch.bfloat16), + dtype=model_dtype, ) def get_input_embeddings(self): diff --git a/nemo_automodel/components/models/nemotron_v3/layers.py b/nemo_automodel/components/models/nemotron_v3/layers.py index ce510babea..0ba76c80e4 100644 --- a/nemo_automodel/components/models/nemotron_v3/layers.py +++ b/nemo_automodel/components/models/nemotron_v3/layers.py @@ -29,6 +29,7 @@ initialize_linear_module, initialize_rms_norm_module, ) +from nemo_automodel.shared.utils import dtype_from_str as get_dtype class NemotronV3Attention(nn.Module): @@ -45,17 +46,34 @@ def __init__(self, config, backend: BackendConfig | None = None): self.attention_bias = getattr(config, "attention_bias", False) self.attention_dropout = getattr(config, "attention_dropout", 0.0) + dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) self.q_proj = initialize_linear_module( - self.backend.linear, self.hidden_size, self.num_attention_heads * self.head_dim, self.attention_bias + self.backend.linear, + self.hidden_size, + self.num_attention_heads * self.head_dim, + self.attention_bias, + dtype=dtype, ) self.k_proj = initialize_linear_module( - self.backend.linear, self.hidden_size, self.num_key_value_heads * self.head_dim, self.attention_bias + self.backend.linear, + self.hidden_size, + self.num_key_value_heads * self.head_dim, + self.attention_bias, + dtype=dtype, ) self.v_proj = initialize_linear_module( - self.backend.linear, self.hidden_size, self.num_key_value_heads * self.head_dim, self.attention_bias + self.backend.linear, + self.hidden_size, + self.num_key_value_heads * self.head_dim, + 
self.attention_bias, + dtype=dtype, ) self.o_proj = initialize_linear_module( - self.backend.linear, self.num_attention_heads * self.head_dim, self.hidden_size, self.attention_bias + self.backend.linear, + self.num_attention_heads * self.head_dim, + self.hidden_size, + self.attention_bias, + dtype=dtype, ) softmax_scale = self.head_dim**-0.5 @@ -565,6 +583,7 @@ def __init__(self, config, layer_idx: int, moe_config=None, backend=None): backend.rms_norm, config.hidden_size, eps=config.layer_norm_epsilon, + dtype=get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16), ) # Determine layer type from config diff --git a/nemo_automodel/components/models/nemotron_v3/model.py b/nemo_automodel/components/models/nemotron_v3/model.py index d8aa86cf25..24cd0c8529 100644 --- a/nemo_automodel/components/models/nemotron_v3/model.py +++ b/nemo_automodel/components/models/nemotron_v3/model.py @@ -107,6 +107,7 @@ def __init__( self.backend.rms_norm, config.hidden_size, eps=config.layer_norm_epsilon, + dtype=dtype, ) def forward( diff --git a/nemo_automodel/components/models/qwen3_moe/layers.py b/nemo_automodel/components/models/qwen3_moe/layers.py index 249450d6e9..8c3fc131b3 100644 --- a/nemo_automodel/components/models/qwen3_moe/layers.py +++ b/nemo_automodel/components/models/qwen3_moe/layers.py @@ -29,6 +29,7 @@ initialize_rms_norm_module, ) from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk +from nemo_automodel.shared.utils import dtype_from_str as get_dtype class Qwen3MoeAttention(nn.Module): @@ -52,22 +53,27 @@ def __init__(self, config: Qwen3MoeConfig, backend: BackendConfig): attention_bias = getattr(config, "attention_bias", False) + # Thread dtype explicitly from config.torch_dtype so fp32 master + # weights work even when construction is not wrapped in + # local_torch_dtype(). 
+ dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.q_proj = initialize_linear_module( - backend.linear, config.hidden_size, self.num_heads * self.head_dim, attention_bias + backend.linear, config.hidden_size, self.num_heads * self.head_dim, attention_bias, dtype=dtype ) self.k_proj = initialize_linear_module( - backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias, dtype=dtype ) self.v_proj = initialize_linear_module( - backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias, dtype=dtype ) self.o_proj = initialize_linear_module( - backend.linear, self.num_heads * self.head_dim, config.hidden_size, attention_bias + backend.linear, self.num_heads * self.head_dim, config.hidden_size, attention_bias, dtype=dtype ) # Per-head RMSNorm - self.q_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps) - self.k_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps) + self.q_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps, dtype=dtype) + self.k_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps, dtype=dtype) # Attention implementation softmax_scale = self.head_dim**-0.5 diff --git a/nemo_automodel/components/models/qwen3_moe/model.py b/nemo_automodel/components/models/qwen3_moe/model.py index de4a04b6bc..c1126e6173 100644 --- a/nemo_automodel/components/models/qwen3_moe/model.py +++ b/nemo_automodel/components/models/qwen3_moe/model.py @@ -41,6 +41,10 @@ def __init__(self, layer_idx: int, config: Qwen3MoeConfig, moe_config: MoEConfig super().__init__() self.self_attn = Qwen3MoeAttention(config, backend) + # Thread dtype from config.torch_dtype so the block's own params stay + # aligned with the rest of the model (fp32 under fp32 master weights). + dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + # Qwen3-MoE sparsifies every decoder_sparse_step layer, unless in mlp_only_layers is_moe_layer = ( (layer_idx not in getattr(config, "mlp_only_layers", [])) @@ -50,11 +54,13 @@ def __init__(self, layer_idx: int, config: Qwen3MoeConfig, moe_config: MoEConfig if is_moe_layer: self.mlp = MoE(moe_config, backend) else: - self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear) + self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype) - self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype + ) self.post_attention_layernorm = initialize_rms_norm_module( - backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=dtype ) self.layer_idx = layer_idx @@ -111,6 +117,11 @@ def __init__( if moe_config is not None and moe_overrides is not None: raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.") + # Resolve model dtype once from config.torch_dtype and thread it + # explicitly into every sub-module so fp32 master weights work even + # when construction is not wrapped in local_torch_dtype(). 
+ model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + # Map HF Qwen3 MoE config -> our MoE wrapper # Qwen config fields from example config: # - hidden_size, intermediate_size, moe_intermediate_size, num_experts, num_experts_per_tok, norm_topk_prob @@ -133,18 +144,19 @@ def __init__( router_bias=False, expert_activation="swiglu", softmax_before_topk=True, + dtype=model_dtype, ) if moe_overrides: moe_defaults.update(moe_overrides) self.moe_config = moe_config or MoEConfig(**moe_defaults) - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16) - ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype) self.layers = torch.nn.ModuleDict() for layer_id in range(config.num_hidden_layers): self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend) - self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.norm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype + ) # Rotary embedding cache compatible with our rope_utils functions self.max_seq_len = config.max_position_embeddings @@ -250,10 +262,13 @@ def __init__( self.backend = backend or BackendConfig() moe_overrides = kwargs.pop("moe_overrides", None) self.model = Qwen3MoeModel(config, backend=self.backend, moe_config=moe_config, moe_overrides=moe_overrides) - self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False) + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.lm_head = initialize_linear_module( + self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype + ) if self.backend.enable_hf_state_dict_adapter: self.state_dict_adapter = Qwen3MoeStateDictAdapter( - self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16) + self.config, self.model.moe_config, self.backend, dtype=model_dtype ) def get_input_embeddings(self): diff --git a/nemo_automodel/components/models/step3p5/layers.py b/nemo_automodel/components/models/step3p5/layers.py index fa25d40c69..ce06d0320e 100644 --- a/nemo_automodel/components/models/step3p5/layers.py +++ b/nemo_automodel/components/models/step3p5/layers.py @@ -28,6 +28,7 @@ initialize_linear_module, ) from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk +from nemo_automodel.shared.utils import dtype_from_str as get_dtype class Step3p5RMSNorm(nn.Module): @@ -145,9 +146,16 @@ def __init__( self.intermediate_size = intermediate_size or config.intermediate_size self.swiglu_limit = swiglu_limit - self.gate_proj = initialize_linear_module(backend.linear, self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = initialize_linear_module(backend.linear, self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = initialize_linear_module(backend.linear, self.intermediate_size, self.hidden_size, bias=False) + dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.gate_proj = initialize_linear_module( + backend.linear, self.hidden_size, self.intermediate_size, bias=False, dtype=dtype + ) + self.up_proj = initialize_linear_module( + backend.linear, self.hidden_size, self.intermediate_size, bias=False, dtype=dtype + ) + self.down_proj = initialize_linear_module( + backend.linear, self.intermediate_size, 
self.hidden_size, bias=False, dtype=dtype + ) def forward(self, x: torch.Tensor) -> torch.Tensor: up = self.up_proj(x) @@ -198,17 +206,18 @@ def __init__(self, config: Any, layer_idx: int, backend: BackendConfig) -> None: # Projections attention_bias = getattr(config, "attention_bias", False) + dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) self.q_proj = initialize_linear_module( - backend.linear, config.hidden_size, self.num_heads * self.head_dim, attention_bias + backend.linear, config.hidden_size, self.num_heads * self.head_dim, attention_bias, dtype=dtype ) self.k_proj = initialize_linear_module( - backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias, dtype=dtype ) self.v_proj = initialize_linear_module( - backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias, dtype=dtype ) self.o_proj = initialize_linear_module( - backend.linear, self.num_heads * self.head_dim, config.hidden_size, attention_bias + backend.linear, self.num_heads * self.head_dim, config.hidden_size, attention_bias, dtype=dtype ) # Per-head Q/K normalization using Step3p5RMSNorm @@ -218,7 +227,9 @@ def __init__(self, config: Any, layer_idx: int, backend: BackendConfig) -> None: # Optional head-wise attention gate self.use_head_wise_attn_gate = getattr(config, "use_head_wise_attn_gate", False) if self.use_head_wise_attn_gate: - self.g_proj = initialize_linear_module(backend.linear, config.hidden_size, self.num_heads, bias=False) + self.g_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.num_heads, bias=False, dtype=dtype + ) else: self.g_proj = None diff --git a/nemo_automodel/components/models/step3p5/model.py b/nemo_automodel/components/models/step3p5/model.py index e4bb2916d0..e9d4b511c0 100644 --- a/nemo_automodel/components/models/step3p5/model.py +++ b/nemo_automodel/components/models/step3p5/model.py @@ -395,14 +395,17 @@ def __init__( self.backend = backend or BackendConfig() moe_overrides = kwargs.pop("moe_overrides", None) self.model = Step3p5Model(config, backend=self.backend, moe_config=moe_config, moe_overrides=moe_overrides) - self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False) + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.lm_head = initialize_linear_module( + self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype + ) if self.backend.enable_hf_state_dict_adapter: self.state_dict_adapter = Step3p5StateDictAdapter( self.config, self.model.moe_config, self.backend, - dtype=get_dtype(getattr(config, "torch_dtype", "bfloat16"), torch.bfloat16), + dtype=model_dtype, ) def get_input_embeddings(self): From ce6d2c95c0fb27a8e50c62196e468761f8e11d52 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Fri, 17 Apr 2026 03:14:31 -0700 Subject: [PATCH 5/6] fix(models): thread dtype + propagate to sub-configs for VL MoE models Extend the fp32-master-weights fix pattern to the remaining VL MoE models. These all share the same VL-specific gotcha: _init_model() only overrides the top-level hf_config.torch_dtype, but the real params live under nested sub-configs (e.g. config.text_config, config.vision_config, or config.thinker_config.text_config). 
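For illustration, the post-_init_model() state on such a config looks roughly like this (a minimal sketch; dtype values illustrative, assuming fp32 master weights were requested):

    config.torch_dtype                # torch.float32  -- overridden by _init_model()
    config.text_config.torch_dtype    # torch.bfloat16 -- still the checkpoint value
    config.vision_config.torch_dtype  # torch.bfloat16 -- still the checkpoint value
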
Those nested torch_dtype values retain the checkpoint's original bf16, leaving a mixed-dtype state when our text backend reads sub_config.torch_dtype while HF-native submodules pick up get_default_dtype() from local_torch_dtype(). The symptom is the same as qwen3_5_moe: a generic CuBLAS Error at the second grouped GEMM during MoE forward. Pattern applied per VL wrapper: * Before super().__init__(config), walk vars(config) and propagate config.torch_dtype to every nested sub-config that exposes a torch_dtype attribute. For Omni's two-level nesting (config.thinker_config.text_config) this walk is performed at both levels. * Text backend __init__: resolve model_dtype once from config.torch_dtype and thread it into MoEConfig(dict), nn.Embedding, and the final initialize_rms_norm_module(dtype=...). * VL ForConditionalGeneration: thread model_dtype into lm_head's initialize_linear_module and the state_dict_adapter dtype arg. Models updated: * qwen3_vl_moe * qwen3_omni_moe (two-level sub-config walk) * gemma4_moe (also removes a dead get_dtype(...) call that was computing a value and throwing it away; the value now drives MoE expert dtype) * mistral4 (top-level Mistral4Model + Mistral4TextModelBackend VL wrapper + Mistral3ForConditionalGeneration sub-config propagation) HF-native submodules (attention, Gemma4MLP, Gemma4RMSNorm, Gemma4TextScaledWordEmbedding, Mistral3MultiModalProjector, vision tower) continue to inherit their dtype from local_torch_dtype() via torch.get_default_dtype(), so no changes are needed there. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Zhaopeng Qiu --- .../components/models/gemma4_moe/model.py | 22 +++++++++- .../components/models/mistral4/model.py | 42 +++++++++++++++---- .../components/models/qwen3_omni_moe/model.py | 34 +++++++++++++-- .../components/models/qwen3_vl_moe/model.py | 32 +++++++++++--- 4 files changed, 113 insertions(+), 17 deletions(-) diff --git a/nemo_automodel/components/models/gemma4_moe/model.py b/nemo_automodel/components/models/gemma4_moe/model.py index b4dea38e1d..6e9a3520ee 100644 --- a/nemo_automodel/components/models/gemma4_moe/model.py +++ b/nemo_automodel/components/models/gemma4_moe/model.py @@ -259,6 +259,14 @@ def __init__( self.padding_idx = getattr(config, "pad_token_id", None) self.vocab_size = config.vocab_size + # Resolve model dtype once from config.torch_dtype and thread it into + # MoEConfig so the MoE expert params stay aligned with the rest of the + # model (fp32 under fp32 master weights). HF-native submodules + # (attention, Gemma4MLP, Gemma4RMSNorm, Gemma4TextScaledWordEmbedding) + # inherit their dtype from torch.get_default_dtype() via the + # local_torch_dtype() context established by _init_model(). 
+ model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + moe_defaults = dict( dim=config.hidden_size, inter_dim=config.intermediate_size, @@ -277,12 +285,12 @@ def __init__( norm_topk_prob=True, expert_activation="geglu", softmax_before_topk=False, + dtype=model_dtype, ) if moe_overrides: moe_defaults.update(moe_overrides) self.moe_config = moe_config or MoEConfig(**moe_defaults) - get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) self.embed_tokens = Gemma4TextScaledWordEmbedding( config.vocab_size, config.hidden_size, @@ -449,6 +457,18 @@ def __init__( if not getattr(cfg_text, "expert_intermediate_size", None) and getattr(cfg_text, "moe_intermediate_size", None): cfg_text.expert_intermediate_size = cfg_text.moe_intermediate_size + # _init_model() only overrides the top-level hf_config.torch_dtype; for + # VL configs the nested text_config / vision_config keep their original + # dtype (typically bf16 from the checkpoint's config.json). Propagate + # the user-requested dtype to every nested sub-config that exposes a + # torch_dtype attribute, before constructing the HF parent and our + # text backend. + top_dtype = getattr(config, "torch_dtype", None) + if top_dtype is not None: + for sub_cfg in vars(config).values(): + if sub_cfg is not config and hasattr(sub_cfg, "torch_dtype"): + sub_cfg.torch_dtype = top_dtype + # Initialize the HF parent (creates self.model, self.lm_head, vision tower, etc.) super().__init__(config) diff --git a/nemo_automodel/components/models/mistral4/model.py b/nemo_automodel/components/models/mistral4/model.py index b34e0b3d65..b0317bc1f4 100644 --- a/nemo_automodel/components/models/mistral4/model.py +++ b/nemo_automodel/components/models/mistral4/model.py @@ -174,6 +174,7 @@ def _build_moe_config(config, moe_overrides: dict | None = None) -> MoEConfig: route_scale=config.routed_scaling_factor, aux_loss_coeff=0, norm_topk_prob=config.norm_topk_prob, + dtype=get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16), ) if moe_overrides: moe_defaults.update(moe_overrides) @@ -195,13 +196,19 @@ def __init__( if moe_config is not None and moe_overrides is not None: raise ValueError("Cannot pass both moe_config and moe_overrides; use one or the other.") self.moe_config = moe_config or _build_moe_config(config, moe_overrides=moe_overrides) - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16) - ) + + # Resolve model dtype once; thread it explicitly to every sub-module + # so fp32 master weights work even when construction is not wrapped in + # local_torch_dtype(). 
+ model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=model_dtype) self.layers = torch.nn.ModuleDict() for layer_id in range(config.num_hidden_layers): self.layers[str(layer_id)] = Mistral4Block(layer_id, config, self.moe_config, backend) - self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.norm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype + ) self.max_seq_len = config.max_position_embeddings rope_theta, rope_scaling, _ = get_rope_config(config) @@ -337,10 +344,13 @@ def __init__( moe_config=moe_config, moe_overrides=moe_overrides, ) - self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False) + model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + self.lm_head = initialize_linear_module( + self.backend.linear, config.hidden_size, config.vocab_size, bias=False, dtype=model_dtype + ) if self.backend.enable_hf_state_dict_adapter: self.state_dict_adapter = Mistral4StateDictAdapter( - self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16) + self.config, self.model.moe_config, self.backend, dtype=model_dtype ) def get_input_embeddings(self): @@ -468,7 +478,13 @@ def __init__( self.moe_config = self.model.moe_config # lm_head lives inside language_model (like KimiVLLanguageModelBackend) # so the parallelizer wraps it as part of _model, matching the Kimi pattern. - self.lm_head = initialize_linear_module(backend.linear, config.hidden_size, config.vocab_size, bias=False) + self.lm_head = initialize_linear_module( + backend.linear, + config.hidden_size, + config.vocab_size, + bias=False, + dtype=get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16), + ) @property def embed_tokens(self): @@ -675,6 +691,18 @@ def __init__( if num_hidden_layers is not None: config.text_config.num_hidden_layers = num_hidden_layers + # _init_model() only overrides the top-level hf_config.torch_dtype; + # for VL configs the nested text_config / vision_config keep their + # original dtype (typically bf16 from the checkpoint's + # config.json). Propagate the user-requested dtype to every nested + # sub-config that exposes a torch_dtype attribute, before building + # the vision tower / text backend. + top_dtype = getattr(config, "torch_dtype", None) + if top_dtype is not None: + for sub_cfg in vars(config).values(): + if sub_cfg is not config and hasattr(sub_cfg, "torch_dtype"): + sub_cfg.torch_dtype = top_dtype + self.config = config self.backend = backend text_config = config.text_config diff --git a/nemo_automodel/components/models/qwen3_omni_moe/model.py b/nemo_automodel/components/models/qwen3_omni_moe/model.py index 3766dbf588..5e9374ff0d 100644 --- a/nemo_automodel/components/models/qwen3_omni_moe/model.py +++ b/nemo_automodel/components/models/qwen3_omni_moe/model.py @@ -60,6 +60,12 @@ def __init__( # Map HF Qwen3OmniMoe config -> our MoE wrapper self.padding_idx = getattr(config, "pad_token_id", None) self.vocab_size = config.vocab_size + + # Resolve model dtype once; thread it explicitly to every sub-module + # so fp32 master weights work even when construction is not wrapped in + # local_torch_dtype(). 
+ model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + moe_defaults = dict( dim=config.hidden_size, inter_dim=config.intermediate_size, @@ -79,16 +85,19 @@ def __init__( router_bias=False, expert_activation="swiglu", softmax_before_topk=True, + dtype=model_dtype, ) if moe_overrides: moe_defaults.update(moe_overrides) self.moe_config = moe_config or MoEConfig(**moe_defaults) - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=model_dtype) self.layers = nn.ModuleList( [Block(layer_id, config, self.moe_config, backend) for layer_id in range(config.num_hidden_layers)] ) - self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.norm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype + ) self.rotary_emb = Qwen3OmniMoeThinkerTextRotaryEmbedding(config) def forward( @@ -222,6 +231,22 @@ def __init__( base_config = config.thinker_config if hasattr(config, "thinker_config") else config backend = backend or BackendConfig() + # _init_model() only overrides the top-level hf_config.torch_dtype; for + # Omni configs the real params live under thinker_config.text_config / + # thinker_config.vision_config, whose torch_dtype still holds the + # checkpoint's original value. Propagate the user-requested dtype to + # every nested sub-config that exposes a torch_dtype attribute (both + # the top-level and the thinker sub-config) so the HF parent, our + # text backend, and any HF vision / multimodal code agree. + top_dtype = getattr(config, "torch_dtype", None) + if top_dtype is not None: + for parent in (config, base_config): + for sub_cfg in vars(parent).values(): + if sub_cfg is not parent and hasattr(sub_cfg, "torch_dtype"): + sub_cfg.torch_dtype = top_dtype + if base_config is not config and getattr(base_config, "torch_dtype", None) != top_dtype: + base_config.torch_dtype = top_dtype + super().__init__(base_config) self.backend = backend @@ -231,8 +256,9 @@ def __init__( self.model = Qwen3OmniMoeThinkerTextModel( text_config, backend=self.backend, moe_config=moe_config, moe_overrides=moe_overrides ) + model_dtype = get_dtype(getattr(text_config, "torch_dtype", None), torch.bfloat16) self.lm_head = initialize_linear_module( - self.backend.linear, text_config.hidden_size, text_config.vocab_size, bias=False + self.backend.linear, text_config.hidden_size, text_config.vocab_size, bias=False, dtype=model_dtype ) self.vocab_size = text_config.vocab_size @@ -250,7 +276,7 @@ def __init__( text_config, self.model.moe_config, self.backend, - dtype=get_dtype(text_config.torch_dtype, torch.bfloat16), + dtype=model_dtype, ) def get_input_embeddings(self): diff --git a/nemo_automodel/components/models/qwen3_vl_moe/model.py b/nemo_automodel/components/models/qwen3_vl_moe/model.py index 3f50516b86..60d7132ade 100644 --- a/nemo_automodel/components/models/qwen3_vl_moe/model.py +++ b/nemo_automodel/components/models/qwen3_vl_moe/model.py @@ -292,6 +292,11 @@ def __init__( self.padding_idx = getattr(config, "pad_token_id", None) self.vocab_size = config.vocab_size + # Resolve model dtype once; thread it explicitly to every sub-module + # so fp32 master weights work even when construction is not wrapped in + # local_torch_dtype(). 
+ model_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) + moe_defaults = dict( dim=config.hidden_size, inter_dim=config.intermediate_size, @@ -311,20 +316,22 @@ def __init__( router_bias=False, expert_activation="swiglu", softmax_before_topk=True, + dtype=model_dtype, ) if moe_overrides: moe_defaults.update(moe_overrides) self.moe_config = moe_config or MoEConfig(**moe_defaults) - embed_dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=embed_dtype) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=model_dtype) self.layers = nn.ModuleDict( { str(layer_id): Qwen3VLMoeBlock(layer_id, config, self.moe_config, backend) for layer_id in range(config.num_hidden_layers) } ) - self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.norm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps, dtype=model_dtype + ) self.rotary_emb = Fp32SafeQwen3VLMoeTextRotaryEmbedding(config=config) def forward( @@ -462,6 +469,20 @@ def __init__( **kwargs, ): backend = backend or BackendConfig() + + # _init_model() only overrides the top-level hf_config.torch_dtype; for + # VL configs the nested text_config / vision_config keep their original + # dtype (typically bf16 from the checkpoint's config.json). Propagate + # the user-requested dtype to every nested sub-config that exposes a + # torch_dtype attribute, before constructing the HF parent (whose + # vision encoder / multimodal code may read sub-config torch_dtype) + # and our text backend. + top_dtype = getattr(config, "torch_dtype", None) + if top_dtype is not None: + for sub_cfg in vars(config).values(): + if sub_cfg is not config and hasattr(sub_cfg, "torch_dtype"): + sub_cfg.torch_dtype = top_dtype + super().__init__(config) self.backend = backend @@ -472,8 +493,9 @@ def __init__( self.model.language_model = Qwen3VLMoeTextModelBackend( text_config, backend=self.backend, moe_config=moe_config, moe_overrides=moe_overrides ) + model_dtype = get_dtype(getattr(text_config, "torch_dtype", None), torch.bfloat16) self.lm_head = initialize_linear_module( - self.backend.linear, text_config.hidden_size, text_config.vocab_size, bias=False + self.backend.linear, text_config.hidden_size, text_config.vocab_size, bias=False, dtype=model_dtype ) self.model.moe_config = self.model.language_model.moe_config @@ -486,7 +508,7 @@ def __init__( text_config, self.model.language_model.moe_config, self.backend, - dtype=get_dtype(text_config.torch_dtype, torch.bfloat16), + dtype=model_dtype, ) vision_model = getattr(self.model, "visual") From 7608e159120125e946bb67cd37ff081ed5c8f347 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Mon, 20 Apr 2026 02:37:34 -0700 Subject: [PATCH 6/6] fix(shared): dtype_from_str falls back to default for non-str/non-dtype After threading `get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)` into every custom MoE model's __init__, a number of pre-existing unit tests started failing with TypeError: can only concatenate str (not "Mock") to str The tests build their `config` via `unittest.mock.Mock()`, so `getattr(config, "torch_dtype", None)` returns an auto-generated Mock rather than None, which dtype_from_str() then tried to concatenate inside "torch." + val.lower(). Real config objects only ever carry torch.dtype, str, or the attribute simply isn't set. 
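A minimal repro of the failure mode and the intended fallback (test-style sketch; usage illustrative):

    from unittest.mock import Mock

    import torch
    from nemo_automodel.shared.utils import dtype_from_str

    cfg = Mock()
    val = getattr(cfg, "torch_dtype", None)   # auto-created Mock attribute, not None
    # before: dtype_from_str(val) reached "torch." + val.lower() -> TypeError
    # after:  dtype_from_str(val, torch.bfloat16) == torch.bfloat16 (falls back to default)
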
Treat anything else (Mock, or other unexpected types) as "no explicit dtype" and fall back to the supplied default, instead of crashing with a cryptic TypeError deep inside string coercion. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Zhaopeng Qiu --- nemo_automodel/shared/utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/nemo_automodel/shared/utils.py b/nemo_automodel/shared/utils.py index 021538168f..2f02296440 100644 --- a/nemo_automodel/shared/utils.py +++ b/nemo_automodel/shared/utils.py @@ -29,6 +29,16 @@ def dtype_from_str(val, default=torch.bfloat16): if isinstance(val, torch.dtype): return val + + # val must be a recognized dtype string below; anything else (e.g. a + # Mock attribute auto-created in tests that build configs with + # unittest.mock.Mock) falls back to the default so a call site like + # get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16) stays + # safe instead of crashing deep inside str concatenation. + if not isinstance(val, str): + assert isinstance(default, torch.dtype), default + return default + lut = { "torch.float": torch.float, "torch.float32": torch.float,