diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index a2d745864d92..f4ace00292f9 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -118,7 +118,9 @@ def load_scaling_factors(state_dict: dict, basename: str, size: int) -> Optional def filter_experts_extra_states(state_dict: dict): - pattern = r'model\.decoder\.layers\.mlp\.experts\.experts\.linear_fc\d+\._extra_state/shard_\d+\.\d+_\d+\.\d+' + pattern = ( + r'(model|module)\.decoder\.layers\.mlp\.experts\.experts\.linear_fc\d+\._extra_state/shard_\d+\.\d+_\d+\.\d+' + ) return {k: v for k, v in state_dict.items() if not re.fullmatch(pattern, k)}