From f8834482d11dfef3cd72ce957588201ff03d8ac9 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 14 Mar 2025 12:12:41 +0000 Subject: [PATCH 1/2] don't gc collect if 1 shard is used --- src/transformers/modeling_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 77f842aa5f6d..b58b484b3a5c 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4817,6 +4817,7 @@ def _load_pretrained_model( error_msgs = [] mismatched_keys = [] + has_multiple_shards = len(checkpoint_files) > 1 # Iterate on all the shards to load the weights for shard_file in checkpoint_files: # Skip the load for shards that only contain disk-offloaded weights @@ -4835,7 +4836,7 @@ def _load_pretrained_model( ): map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) - # If shard_file is""", we use the existing state_dict instead of loading it + # If shard_file is "", we use the existing state_dict instead of loading it if shard_file != "": state_dict = load_state_dict( shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only @@ -4881,9 +4882,10 @@ def _load_pretrained_model( else: model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params) - # force memory release - del state_dict - gc.collect() + # force memory release if loading multiple shards + if has_multiple_shards: + del state_dict + gc.collect() # Adjust offloaded weights name and save if needed if disk_offload_index is not None and len(disk_offload_index) > 0: From 9dac0a4b9137df727623446852ef5bbc307f5630 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 14 Mar 2025 12:34:21 +0000 Subject: [PATCH 2/2] delete state dict anyways --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b58b484b3a5c..40c675c70bf6 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4882,9 +4882,9 @@ def _load_pretrained_model( else: model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params) + del state_dict # force memory release if loading multiple shards if has_multiple_shards: - del state_dict gc.collect() # Adjust offloaded weights name and save if needed