From 2a3c372b7d605967ce594ed8902e5c64854c6514 Mon Sep 17 00:00:00 2001 From: Po-Han Huang Date: Sun, 21 Dec 2025 21:58:06 -0800 Subject: [PATCH] [https://nvbugs/5732958][bug] Fix TestLlama4MinLatency::test_llama_allclose_to_hf failure Make Llama/Llama4 forward pass work correctly both with and without post_load_weights() being called, by making the layernorm fusion gracefully degrade: - In __init__(), add a skip_input_layernorm flag on the decoder layers and a skip_norm flag on the model, all defaulting to False, so forward() stays correct when post_load_weights() is never called. - In post_load_weights(), when chaining a layernorm into the previous layer (or the final norm into the last decoder layer), set the corresponding skip flag to True to mark that layernorm as absorbed. - In DecoderLayer.forward(), apply input_layernorm only when it has not been absorbed by the previous layer. - In Model.forward(), apply self.norm only when it has not been absorbed into the last decoder layer. - Call post_load_weights() in the allclose test and remove the transformers>=4.57.1 skip there, since the root cause (missing post_load_weights) is now fixed. Signed-off-by: Po-Han Huang --- tensorrt_llm/_torch/models/modeling_llama.py | 32 +++++++++++++++++++ .../_torch/modeling/test_modeling_llama.py | 1 + .../test_modeling_llama_min_latency.py | 6 ++-- 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 54193a32c08b..743e0b8ef502 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -449,6 +449,10 @@ def __init__( self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype) + # When post_load_weights() chains layernorms across layers, + # this flag is set to True to skip the input layernorm in + # forward() since it is handled by the previous layer.
+ self.skip_input_layernorm = False self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, @@ -493,6 +497,8 @@ def forward( if residual is None: residual = hidden_states + + if not self.skip_input_layernorm: hidden_states = self.input_layernorm(hidden_states) # Self Attention @@ -668,6 +674,10 @@ def __init__( quantize_type="nvfp4" if not self.disable_nvfp4_layernorm_fusion and self.is_nvfp4 and not (differ_pp_stage_with_previous_layer) else None) + # When post_load_weights() chains layernorms across layers, + # this flag is set to True to skip the input layernorm in + # forward() since it is handled by the previous layer. + self.skip_input_layernorm = False self.post_attention_layernorm = RMSNorm( hidden_size=config.hidden_size, @@ -765,6 +775,8 @@ def forward( ) -> Union[torch.Tensor, Fp4QuantizedTensor]: if residual is None: residual = hidden_states + + if not self.skip_input_layernorm: hidden_states = self.input_layernorm(hidden_states) hidden_states = self.self_attn( @@ -936,6 +948,10 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]): self.norm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype) + # When post_load_weights() chains the final norm into the + # last decoder layer, this flag is set to True to skip + # applying it again in forward(). + self.skip_norm = False def forward( self, @@ -969,6 +985,10 @@ def forward( lora_params=lora_params, ) + # If self.norm is not handled by the last layer, apply it here. + if not self.skip_norm: + hidden_states = self.norm(hidden_states) + return hidden_states @@ -1033,6 +1053,10 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]): self.norm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype) + # When post_load_weights() chains the final norm into the + # last decoder layer, this flag is set to True to skip + # applying it again in forward(). 
+ self.skip_norm = False def forward( self, @@ -1065,6 +1089,10 @@ def forward( lora_params=lora_params, ) + # If self.norm is not handled by the last layer, apply it here. + if not self.skip_norm: + hidden_states = self.norm(hidden_states) + return hidden_states @@ -1082,9 +1110,11 @@ def post_load_weights(self): self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: layer.next_layer_layernorm = self.model.norm + self.model.skip_norm = True else: layer.next_layer_layernorm = self.model.layers[ idx + 1].input_layernorm + self.model.layers[idx + 1].skip_input_layernorm = True layer.next_attn = self.model.layers[idx + 1].self_attn @@ -1456,9 +1486,11 @@ def post_load_weights(self): self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: layer.next_layer_layernorm = self.model.norm + self.model.skip_norm = True else: layer.next_layer_layernorm = self.model.layers[ idx + 1].input_layernorm + self.model.layers[idx + 1].skip_input_layernorm = True layer.next_attn = self.model.layers[idx + 1].self_attn diff --git a/tests/unittest/_torch/modeling/test_modeling_llama.py b/tests/unittest/_torch/modeling/test_modeling_llama.py index ca503642c670..334e60a61e97 100644 --- a/tests/unittest/_torch/modeling/test_modeling_llama.py +++ b/tests/unittest/_torch/modeling/test_modeling_llama.py @@ -407,6 +407,7 @@ def test_llama_verification_with_kv_cache_relocation(self) -> None: llama = LlamaForCausalLM(model_config).to(dtype).to(device) llama.load_weights(hf_llama.state_dict()) + num_blocks = 2 tokens_per_block = 32 head_dim = llama.config.hidden_size // llama.config.num_attention_heads diff --git a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py index 599b1be02119..0ce83559923d 100644 --- a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py +++ 
b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py @@ -271,10 +271,7 @@ def test_llama_allclose_to_hf(self, scenario: AllCloseScenario) -> None: "The transformers between 4.55.0 and 4.56.1 have accuracy " "issues for Llama4. See: " "https://github.com/huggingface/transformers/pull/40609") - elif transformers.__version__ >= "4.57.1": - self.skipTest( - "Bumping transformers version to 4.57.1 has accuracy issues for Llama4. See: " - "http://nvbugs/5732958") + torch.random.manual_seed(0) config_dict = deepcopy(LLAMA_4_MAVERICK_TWO_LAYER_CONFIG) # 17B * sizeof(float16) plus some extra for activations @@ -301,6 +298,7 @@ def test_llama_allclose_to_hf(self, scenario: AllCloseScenario) -> None: weight_mapper.init_model_and_config(llama, model_config) llama.load_weights(hf_llama.state_dict(), weight_mapper=weight_mapper) + llama.post_load_weights() num_blocks = 1 tokens_per_block = 128