From 8182694972abfc7ee9318b6e5ac1efe100d447df Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Tue, 10 Dec 2024 04:12:31 -0800 Subject: [PATCH 1/2] Initial commit Signed-off-by: Piotr Kaminski --- nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index a2d745864d92..f49e34ea5dc7 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -118,7 +118,7 @@ def load_scaling_factors(state_dict: dict, basename: str, size: int) -> Optional def filter_experts_extra_states(state_dict: dict): - pattern = r'model\.decoder\.layers\.mlp\.experts\.experts\.linear_fc\d+\._extra_state/shard_\d+\.\d+_\d+\.\d+' + pattern = r'(model|module)\.decoder\.layers\.mlp\.experts\.experts\.linear_fc\d+\._extra_state/shard_\d+\.\d+_\d+\.\d+' return {k: v for k, v in state_dict.items() if not re.fullmatch(pattern, k)} From 0e0d9a12fab7f7edecab55f1beba75e458db8e59 Mon Sep 17 00:00:00 2001 From: Laplasjan107 Date: Tue, 10 Dec 2024 12:15:20 +0000 Subject: [PATCH 2/2] Apply isort and black reformatting Signed-off-by: Laplasjan107 --- nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index f49e34ea5dc7..f4ace00292f9 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -118,7 +118,9 @@ def load_scaling_factors(state_dict: dict, basename: str, size: int) -> Optional def filter_experts_extra_states(state_dict: dict): - pattern = r'(model|module)\.decoder\.layers\.mlp\.experts\.experts\.linear_fc\d+\._extra_state/shard_\d+\.\d+_\d+\.\d+' + pattern = ( + r'(model|module)\.decoder\.layers\.mlp\.experts\.experts\.linear_fc\d+\._extra_state/shard_\d+\.\d+_\d+\.\d+' + ) return {k: v for k, v in state_dict.items() if not re.fullmatch(pattern, k)}