diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py
index 54193a32c08b..743e0b8ef502 100644
--- a/tensorrt_llm/_torch/models/modeling_llama.py
+++ b/tensorrt_llm/_torch/models/modeling_llama.py
@@ -449,6 +449,10 @@ def __init__(
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
                                        dtype=config.torch_dtype)
+        # When post_load_weights() chains layernorms across layers,
+        # this flag is set to True to skip the input layernorm in
+        # forward() since it is handled by the previous layer.
+        self.skip_input_layernorm = False
 
         self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                                 eps=config.rms_norm_eps,
@@ -493,6 +497,8 @@ def forward(
 
         if residual is None:
             residual = hidden_states
+
+        if not self.skip_input_layernorm:
             hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
@@ -668,6 +674,10 @@ def __init__(
             quantize_type="nvfp4" if not self.disable_nvfp4_layernorm_fusion
             and self.is_nvfp4 and not
             (differ_pp_stage_with_previous_layer) else None)
+        # When post_load_weights() chains layernorms across layers,
+        # this flag is set to True to skip the input layernorm in
+        # forward() since it is handled by the previous layer.
+        self.skip_input_layernorm = False
 
         self.post_attention_layernorm = RMSNorm(
             hidden_size=config.hidden_size,
@@ -765,6 +775,8 @@ def forward(
     ) -> Union[torch.Tensor, Fp4QuantizedTensor]:
         if residual is None:
             residual = hidden_states
+
+        if not self.skip_input_layernorm:
             hidden_states = self.input_layernorm(hidden_states)
 
         hidden_states = self.self_attn(
@@ -936,6 +948,10 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
         self.norm = RMSNorm(hidden_size=config.hidden_size,
                             eps=config.rms_norm_eps,
                             dtype=config.torch_dtype)
+        # When post_load_weights() chains the final norm into the
+        # last decoder layer, this flag is set to True to skip
+        # applying it again in forward().
+        self.skip_norm = False
 
     def forward(
         self,
@@ -969,6 +985,10 @@ def forward(
                 lora_params=lora_params,
             )
 
+        # If self.norm is not handled by the last layer, apply it here.
+        if not self.skip_norm:
+            hidden_states = self.norm(hidden_states)
+
         return hidden_states
 
 
@@ -1033,6 +1053,10 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
         self.norm = RMSNorm(hidden_size=config.hidden_size,
                             eps=config.rms_norm_eps,
                             dtype=config.torch_dtype)
+        # When post_load_weights() chains the final norm into the
+        # last decoder layer, this flag is set to True to skip
+        # applying it again in forward().
+        self.skip_norm = False
 
     def forward(
         self,
@@ -1065,6 +1089,10 @@ def forward(
                 lora_params=lora_params,
             )
 
+        # If self.norm is not handled by the last layer, apply it here.
+        if not self.skip_norm:
+            hidden_states = self.norm(hidden_states)
+
         return hidden_states
 
 
@@ -1082,9 +1110,11 @@ def post_load_weights(self):
                 self.model.layers[:self.config.num_hidden_layers]):
             if idx == self.config.num_hidden_layers - 1:
                 layer.next_layer_layernorm = self.model.norm
+                self.model.skip_norm = True
             else:
                 layer.next_layer_layernorm = self.model.layers[
                     idx + 1].input_layernorm
+                self.model.layers[idx + 1].skip_input_layernorm = True
                 layer.next_attn = self.model.layers[idx + 1].self_attn
 
 
@@ -1456,9 +1486,11 @@ def post_load_weights(self):
                 self.model.layers[:self.config.num_hidden_layers]):
             if idx == self.config.num_hidden_layers - 1:
                 layer.next_layer_layernorm = self.model.norm
+                self.model.skip_norm = True
             else:
                 layer.next_layer_layernorm = self.model.layers[
                     idx + 1].input_layernorm
+                self.model.layers[idx + 1].skip_input_layernorm = True
                 layer.next_attn = self.model.layers[idx + 1].self_attn
 
 
diff --git a/tests/unittest/_torch/modeling/test_modeling_llama.py b/tests/unittest/_torch/modeling/test_modeling_llama.py
index ca503642c670..334e60a61e97 100644
--- a/tests/unittest/_torch/modeling/test_modeling_llama.py
+++ b/tests/unittest/_torch/modeling/test_modeling_llama.py
@@ -407,6 +407,7 @@ def test_llama_verification_with_kv_cache_relocation(self) -> None:
         llama = LlamaForCausalLM(model_config).to(dtype).to(device)
         llama.load_weights(hf_llama.state_dict())
+        num_blocks = 2
         tokens_per_block = 32
         head_dim = llama.config.hidden_size // llama.config.num_attention_heads
diff --git a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py
index 599b1be02119..0ce83559923d 100644
--- a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py
+++ b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py
@@ -271,10 +271,7 @@ def test_llama_allclose_to_hf(self, scenario: AllCloseScenario) -> None:
                 "The transformers between 4.55.0 and 4.56.1 have accuracy "
                 "issues for Llama4. See: "
                 "https://github.com/huggingface/transformers/pull/40609")
-        elif transformers.__version__ >= "4.57.1":
-            self.skipTest(
-                "Bumping transformers version to 4.57.1 has accuracy issues for Llama4. See: "
-                "http://nvbugs/5732958")
+
         torch.random.manual_seed(0)
         config_dict = deepcopy(LLAMA_4_MAVERICK_TWO_LAYER_CONFIG)
         # 17B * sizeof(float16) plus some extra for activations
@@ -301,6 +298,7 @@ def test_llama_allclose_to_hf(self, scenario: AllCloseScenario) -> None:
         weight_mapper.init_model_and_config(llama, model_config)
         llama.load_weights(hf_llama.state_dict(), weight_mapper=weight_mapper)
+        llama.post_load_weights()
 
         num_blocks = 1
         tokens_per_block = 128
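
For context, the sketch below illustrates the chaining pattern this patch relies on: post_load_weights() wires each decoder layer's next_layer_layernorm to the following layer's input_layernorm (or to the model's final norm for the last layer), and the new skip_input_layernorm / skip_norm flags make the chained norm's original call site a no-op so it is never applied twice. This is a minimal, self-contained sketch, not the TensorRT-LLM implementation: ToyDecoderLayer, ToyModel, and the use of nn.LayerNorm in place of RMSNorm are illustrative assumptions, not code from the patch.

# Minimal sketch of the skip-flag chaining pattern (illustrative only).
import torch
import torch.nn as nn


class ToyDecoderLayer(nn.Module):

    def __init__(self, hidden_size: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Linear(hidden_size, hidden_size)
        # Wired up by ToyModel.post_load_weights(); when set, this layer also
        # applies the next layer's (or the model's final) norm itself.
        self.next_layer_layernorm = None
        self.skip_input_layernorm = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if not self.skip_input_layernorm:
            hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        if self.next_layer_layernorm is not None:
            # Stand-in for the fused residual-add + norm epilogue.
            hidden_states = self.next_layer_layernorm(hidden_states)
        return hidden_states


class ToyModel(nn.Module):

    def __init__(self, hidden_size: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList(
            ToyDecoderLayer(hidden_size) for _ in range(num_layers))
        self.norm = nn.LayerNorm(hidden_size)
        self.skip_norm = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        # If self.norm is not handled by the last layer, apply it here.
        if not self.skip_norm:
            hidden_states = self.norm(hidden_states)
        return hidden_states

    def post_load_weights(self) -> None:
        # Chain each layer to the following layer's input layernorm, or to the
        # final model norm for the last layer, and flag the chained norm so it
        # is skipped at its original call site.
        for idx, layer in enumerate(self.layers):
            if idx == len(self.layers) - 1:
                layer.next_layer_layernorm = self.norm
                self.skip_norm = True
            else:
                layer.next_layer_layernorm = self.layers[idx + 1].input_layernorm
                self.layers[idx + 1].skip_input_layernorm = True


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 8)
    plain = ToyModel(hidden_size=8, num_layers=3)
    chained = ToyModel(hidden_size=8, num_layers=3)
    chained.load_state_dict(plain.state_dict())
    chained.post_load_weights()
    # Both models perform the same sequence of norms and linears, so outputs match.
    assert torch.allclose(plain(x), chained(x), atol=1e-6)

The flags exist because chaining is only established in post_load_weights(); until it runs (or if it never runs), each norm must still be applied at its original call site, which is exactly what the new "if not self.skip_*" guards in the patch restore.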