From 90d8427ceda830f9b74825847b4ed434eff1c305 Mon Sep 17 00:00:00 2001 From: flymin <905370712@qq.com> Date: Tue, 20 Aug 2024 11:17:26 +0800 Subject: [PATCH 1/5] fix bug in load_state_dict_into_model; format error msg --- colossalai/checkpoint_io/general_checkpoint_io.py | 4 ++-- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 4 ++-- colossalai/checkpoint_io/utils.py | 6 +++--- colossalai/inference/core/plugin.py | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index b9253a56dcbb..af21ea0d19ba 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -220,9 +220,9 @@ def load_sharded_model( if strict: remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) if len(remain_keys) > 0: - error_msgs = "Missing key(s) in state_dict: {}. ".format( + error_msgs = ["Missing key(s) in state_dict: {}. ".format( ", ".join('"{}"'.format(k) for k in missing_keys) - ) + )] raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format( self.__class__.__name__, "\n\t".join(error_msgs) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 043e5c2b0618..31ded1cccceb 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -381,9 +381,9 @@ def _load(name: str): remain_keys = remain_keys.union(set(missing_file_keys)) if len(remain_keys) > 0: if strict: - error_msgs = "Missing key(s) in state_dict: {}. ".format( + error_msgs = ["Missing key(s) in state_dict: {}. ".format( ", ".join('"{}"'.format(k) for k in missing_keys) - ) + )] raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format( self.__class__.__name__, "\n\t".join(error_msgs) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 36138f33e9ab..6eba3f9329c2 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -553,7 +553,7 @@ def load_state_dict_into_model( def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs) + args = (state_dict, prefix, local_metadata, True, sub_missing_keys, unexpected_keys, error_msgs) # Parameters of module and children will start with prefix. We can exit early if there are none in this # state_dict if len([key for key in state_dict if key.startswith(prefix)]) > 0: @@ -570,9 +570,9 @@ def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True) if strict: if len(unexpected_keys) > 0: - error_msgs = "Unexpected key(s) in state_dict: {}. ".format( + error_msgs = ["Unexpected key(s) in state_dict: {}. ".format( ", ".join('"{}"'.format(k) for k in unexpected_keys) - ) + )] raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)) ) diff --git a/colossalai/inference/core/plugin.py b/colossalai/inference/core/plugin.py index d6a2b8b16550..e3ae836637b9 100644 --- a/colossalai/inference/core/plugin.py +++ b/colossalai/inference/core/plugin.py @@ -116,9 +116,9 @@ def _load(name: str): remain_keys = remain_keys.union(set(missing_file_keys)) if len(remain_keys) > 0: if strict: - error_msgs = "Missing key(s) in state_dict: {}. ".format( + error_msgs = ["Missing key(s) in state_dict: {}. ".format( ", ".join('"{}"'.format(k) for k in missing_keys) - ) + )] raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format( self.__class__.__name__, "\n\t".join(error_msgs) From dd43228674b4f46d8064ae0845d9530aaa7c08f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Aug 2024 03:25:57 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/checkpoint_io/general_checkpoint_io.py | 6 +++--- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 6 +++--- colossalai/checkpoint_io/utils.py | 6 +++--- colossalai/inference/core/plugin.py | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index af21ea0d19ba..a65dbe242ec7 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -220,9 +220,9 @@ def load_sharded_model( if strict: remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) if len(remain_keys) > 0: - error_msgs = ["Missing key(s) in state_dict: {}. ".format( - ", ".join('"{}"'.format(k) for k in missing_keys) - )] + error_msgs = [ + "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)) + ] raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format( self.__class__.__name__, "\n\t".join(error_msgs) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 31ded1cccceb..3b6917d32fa6 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -381,9 +381,9 @@ def _load(name: str): remain_keys = remain_keys.union(set(missing_file_keys)) if len(remain_keys) > 0: if strict: - error_msgs = ["Missing key(s) in state_dict: {}. ".format( - ", ".join('"{}"'.format(k) for k in missing_keys) - )] + error_msgs = [ + "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)) + ] raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format( self.__class__.__name__, "\n\t".join(error_msgs) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 6eba3f9329c2..c02c22829301 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -570,9 +570,9 @@ def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True) if strict: if len(unexpected_keys) > 0: - error_msgs = ["Unexpected key(s) in state_dict: {}. ".format( - ", ".join('"{}"'.format(k) for k in unexpected_keys) - )] + error_msgs = [ + "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys)) + ] raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)) ) diff --git a/colossalai/inference/core/plugin.py b/colossalai/inference/core/plugin.py index e3ae836637b9..ae526b888eee 100644 --- a/colossalai/inference/core/plugin.py +++ b/colossalai/inference/core/plugin.py @@ -116,9 +116,9 @@ def _load(name: str): remain_keys = remain_keys.union(set(missing_file_keys)) if len(remain_keys) > 0: if strict: - error_msgs = ["Missing key(s) in state_dict: {}. ".format( - ", ".join('"{}"'.format(k) for k in missing_keys) - )] + error_msgs = [ + "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)) + ] raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format( self.__class__.__name__, "\n\t".join(error_msgs) From 40537c57f71ae58b5cb059da1791e2f304276c27 Mon Sep 17 00:00:00 2001 From: "Gao, Ruiyuan" <905370712@qq.com> Date: Tue, 20 Aug 2024 22:59:32 +0800 Subject: [PATCH 3/5] Update utils.py to support checking missing_keys --- colossalai/checkpoint_io/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index c02c22829301..b3917bd9d381 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -556,7 +556,7 @@ def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True) args = (state_dict, prefix, local_metadata, True, sub_missing_keys, unexpected_keys, error_msgs) # Parameters of module and children will start with prefix. We can exit early if there are none in this # state_dict - if len([key for key in state_dict if key.startswith(prefix)]) > 0: + if strict or len([key for key in state_dict if key.startswith(prefix)]) > 0: module._load_from_state_dict(*args) if load_sub_module: for name, child in module._modules.items(): From 6c08a453904e5100f1b166a97ce98d149acb64d4 Mon Sep 17 00:00:00 2001 From: "Gao, Ruiyuan" <905370712@qq.com> Date: Tue, 20 Aug 2024 23:00:31 +0800 Subject: [PATCH 4/5] Update general_checkpoint_io.py fix bug in missing_keys error message --- colossalai/checkpoint_io/general_checkpoint_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index a65dbe242ec7..2534fa163da1 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -221,7 +221,7 @@ def load_sharded_model( remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) if len(remain_keys) > 0: error_msgs = [ - "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)) + "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in remain_keys)) ] raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format( From bd124d7579997290bff70697117dc1a9fa8aebea Mon Sep 17 00:00:00 2001 From: flymin <905370712@qq.com> Date: Fri, 23 Aug 2024 11:23:16 +0800 Subject: [PATCH 5/5] retrigger tests