Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions colossalai/checkpoint_io/general_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ def load_sharded_model(
if strict:
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
if len(remain_keys) > 0:
error_msgs = "Missing key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in missing_keys)
)
error_msgs = [
"Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in remain_keys))
]
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs)
Expand Down
6 changes: 3 additions & 3 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,9 @@ def _load(name: str):
remain_keys = remain_keys.union(set(missing_file_keys))
if len(remain_keys) > 0:
if strict:
error_msgs = "Missing key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in missing_keys)
)
error_msgs = [
"Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))
]
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs)
Expand Down
10 changes: 5 additions & 5 deletions colossalai/checkpoint_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,10 +553,10 @@ def load_state_dict_into_model(

def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, unexpected_keys, error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
if strict or len([key for key in state_dict if key.startswith(prefix)]) > 0:
module._load_from_state_dict(*args)
if load_sub_module:
for name, child in module._modules.items():
Expand All @@ -570,9 +570,9 @@ def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True)

if strict:
if len(unexpected_keys) > 0:
error_msgs = "Unexpected key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in unexpected_keys)
)
error_msgs = [
"Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys))
]
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
)
Expand Down
6 changes: 3 additions & 3 deletions colossalai/inference/core/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def _load(name: str):
remain_keys = remain_keys.union(set(missing_file_keys))
if len(remain_keys) > 0:
if strict:
error_msgs = "Missing key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in missing_keys)
)
error_msgs = [
"Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))
]
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs)
Expand Down