32 changes: 32 additions & 0 deletions tensorrt_llm/_torch/models/modeling_llama.py
@@ -449,6 +449,10 @@ def __init__(
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
# When post_load_weights() chains layernorms across layers,
# this flag is set to True to skip the input layernorm in
# forward() since it is handled by the previous layer.
self.skip_input_layernorm = False

self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
@@ -493,6 +497,8 @@ def forward(

if residual is None:
residual = hidden_states

if not self.skip_input_layernorm:
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
@@ -668,6 +674,10 @@ def __init__(
quantize_type="nvfp4"
if not self.disable_nvfp4_layernorm_fusion and self.is_nvfp4
and not (differ_pp_stage_with_previous_layer) else None)
# When post_load_weights() chains layernorms across layers,
# this flag is set to True to skip the input layernorm in
# forward() since it is handled by the previous layer.
self.skip_input_layernorm = False

self.post_attention_layernorm = RMSNorm(
hidden_size=config.hidden_size,
@@ -765,6 +775,8 @@ def forward(
) -> Union[torch.Tensor, Fp4QuantizedTensor]:
if residual is None:
residual = hidden_states

if not self.skip_input_layernorm:
hidden_states = self.input_layernorm(hidden_states)

hidden_states = self.self_attn(
@@ -936,6 +948,10 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
self.norm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
# When post_load_weights() chains the final norm into the
# last decoder layer, this flag is set to True to skip
# applying it again in forward().
self.skip_norm = False

def forward(
self,
@@ -969,6 +985,10 @@ def forward(
lora_params=lora_params,
)

# If self.norm is not handled by the last layer, apply it here.
if not self.skip_norm:
hidden_states = self.norm(hidden_states)

return hidden_states


@@ -1033,6 +1053,10 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
self.norm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
# When post_load_weights() chains the final norm into the
# last decoder layer, this flag is set to True to skip
# applying it again in forward().
self.skip_norm = False

def forward(
self,
@@ -1065,6 +1089,10 @@ def forward(
lora_params=lora_params,
)

# If self.norm is not handled by the last layer, apply it here.
if not self.skip_norm:
hidden_states = self.norm(hidden_states)

return hidden_states


@@ -1082,9 +1110,11 @@ def post_load_weights(self):
self.model.layers[:self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:
layer.next_layer_layernorm = self.model.norm
self.model.skip_norm = True
else:
layer.next_layer_layernorm = self.model.layers[
idx + 1].input_layernorm
self.model.layers[idx + 1].skip_input_layernorm = True
layer.next_attn = self.model.layers[idx + 1].self_attn


@@ -1456,9 +1486,11 @@ def post_load_weights(self):
self.model.layers[:self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:
layer.next_layer_layernorm = self.model.norm
self.model.skip_norm = True
else:
layer.next_layer_layernorm = self.model.layers[
idx + 1].input_layernorm
self.model.layers[idx + 1].skip_input_layernorm = True
layer.next_attn = self.model.layers[idx + 1].self_attn


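Note: the pattern added in this file pairs each `next_layer_layernorm` assignment in `post_load_weights()` with a `skip_*` flag so the same norm is never applied twice. The following is a minimal, self-contained sketch of that chaining under simplified assumptions (the class names `_Layer` and `_Model` are illustrative, not the real TensorRT-LLM modules; attention/MLP and layernorm fusion are omitted; `torch.nn.RMSNorm` requires PyTorch 2.4+):

```python
import torch
from torch import nn


class _Layer(nn.Module):
    """Illustrative decoder layer that owns its input layernorm."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.input_layernorm = nn.RMSNorm(hidden_size)
        # Wired by _Model.post_load_weights(); points at the next layer's
        # input layernorm, or at the model's final norm for the last layer.
        self.next_layer_layernorm = None
        # True when the previous layer applies this layer's input layernorm
        # as part of its own epilogue.
        self.skip_input_layernorm = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if not self.skip_input_layernorm:
            hidden_states = self.input_layernorm(hidden_states)
        # ... attention / MLP would run here in the real layer ...
        if self.next_layer_layernorm is not None:
            hidden_states = self.next_layer_layernorm(hidden_states)
        return hidden_states


class _Model(nn.Module):
    """Illustrative model wrapping a stack of _Layer plus a final norm."""

    def __init__(self, hidden_size: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList(
            [_Layer(hidden_size) for _ in range(num_layers)])
        self.norm = nn.RMSNorm(hidden_size)
        # True once the last decoder layer owns the final norm.
        self.skip_norm = False

    def post_load_weights(self):
        # Chain each layer into the next one's input layernorm, and fold the
        # final norm into the last layer, mirroring the diff above.
        for idx, layer in enumerate(self.layers):
            if idx == len(self.layers) - 1:
                layer.next_layer_layernorm = self.norm
                self.skip_norm = True
            else:
                layer.next_layer_layernorm = self.layers[idx + 1].input_layernorm
                self.layers[idx + 1].skip_input_layernorm = True

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        if not self.skip_norm:
            hidden_states = self.norm(hidden_states)
        return hidden_states
```

Before `post_load_weights()` runs, every layer applies its own input layernorm and the model applies the final norm; afterwards each normalization is applied exactly once by the preceding layer instead.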
1 change: 1 addition & 0 deletions tests/unittest/_torch/modeling/test_modeling_llama.py
@@ -407,6 +407,7 @@ def test_llama_verification_with_kv_cache_relocation(self) -> None:

llama = LlamaForCausalLM(model_config).to(dtype).to(device)
llama.load_weights(hf_llama.state_dict())

num_blocks = 2
tokens_per_block = 32
head_dim = llama.config.hidden_size // llama.config.num_attention_heads
@@ -271,10 +271,7 @@ def test_llama_allclose_to_hf(self, scenario: AllCloseScenario) -> None:
"The transformers between 4.55.0 and 4.56.1 have accuracy "
"issues for Llama4. See: "
"https://github.com/huggingface/transformers/pull/40609")
elif transformers.__version__ >= "4.57.1":
self.skipTest(
"Bumping transformers version to 4.57.1 has accuracy issues for Llama4. See: "
"http://nvbugs/5732958")

torch.random.manual_seed(0)
config_dict = deepcopy(LLAMA_4_MAVERICK_TWO_LAYER_CONFIG)
# 17B * sizeof(float16) plus some extra for activations
@@ -301,6 +298,7 @@ def test_llama_allclose_to_hf(self, scenario: AllCloseScenario) -> None:
weight_mapper.init_model_and_config(llama, model_config)
llama.load_weights(hf_llama.state_dict(),
weight_mapper=weight_mapper)
llama.post_load_weights()

num_blocks = 1
tokens_per_block = 128
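The `llama.post_load_weights()` call added to the test above is what establishes this chaining (and sets the skip flags) before the forward pass. Continuing the illustrative sketch from earlier, a quick check that the rewiring is output-preserving, since it only changes which module applies each normalization, not the sequence of operations:

```python
# Usage check for the _Model sketch above (illustrative names, not the
# real TensorRT-LLM API).
torch.manual_seed(0)
model = _Model(hidden_size=16, num_layers=3)
x = torch.randn(2, 4, 16)

before = model(x)          # each layer applies its own input layernorm
model.post_load_weights()  # chain norms into the preceding layers
after = model(x)           # same ops in the same order, just relocated

assert torch.allclose(before, after)
```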