From a601d53ba7bb944c0480941365b344a777ed30ca Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 5 Jul 2024 23:29:49 +0200 Subject: [PATCH 1/4] fix load sharded checkpoints from subfolder{ --- src/diffusers/models/model_loading_utils.py | 2 +- src/diffusers/utils/hub_utils.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 5604879f40ab..ebd356d981d6 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -221,7 +221,7 @@ def _fetch_index_file( local_files_only=local_files_only, token=token, revision=revision, - subfolder=subfolder, + subfolder=None, user_agent=user_agent, commit_hash=commit_hash, ) diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index ce90fb09193b..428d1f340d08 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -455,10 +455,13 @@ def _get_checkpoint_shard_files( # At this stage pretrained_model_name_or_path is a model identifier on the Hub allow_patterns = original_shard_filenames + if subfolder is not None: + allow_patterns = [subfolder+ "/" + p for p in allow_patterns] + ignore_patterns = ["*.json", "*.md"] if not local_files_only: # `model_info` call must guarded with the above condition. - model_files_info = model_info(pretrained_model_name_or_path) + model_files_info = model_info(pretrained_model_name_or_path, revision=revision) for shard_file in original_shard_filenames: shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) if not shard_file_present: @@ -481,6 +484,8 @@ def _get_checkpoint_shard_files( ignore_patterns=ignore_patterns, user_agent=user_agent, ) + if subfolder is not None: + cached_folder = cached_folder + "/" + subfolder # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so # we don't have to catch them here. We have also dealt with EntryNotFoundError. From b09df5b071ff31bf1d401c8f16bc912d4b96eb58 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 5 Jul 2024 23:31:06 +0200 Subject: [PATCH 2/4] style --- src/diffusers/utils/hub_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index 428d1f340d08..ae73859143aa 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -456,7 +456,7 @@ def _get_checkpoint_shard_files( # At this stage pretrained_model_name_or_path is a model identifier on the Hub allow_patterns = original_shard_filenames if subfolder is not None: - allow_patterns = [subfolder+ "/" + p for p in allow_patterns] + allow_patterns = [subfolder + "/" + p for p in allow_patterns] ignore_patterns = ["*.json", "*.md"] if not local_files_only: From 2ab02244db4f7935a12b2873f64b3a9944240d88 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 6 Jul 2024 07:15:26 +0530 Subject: [PATCH 3/4] os.path.join --- src/diffusers/utils/hub_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index ae73859143aa..7ecb7de89cd3 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -456,7 +456,7 @@ def _get_checkpoint_shard_files( # At this stage pretrained_model_name_or_path is a model identifier on the Hub allow_patterns = original_shard_filenames if subfolder is not None: - allow_patterns = [subfolder + "/" + p for p in allow_patterns] + allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns] ignore_patterns = ["*.json", "*.md"] if not local_files_only: @@ -485,7 +485,7 @@ def _get_checkpoint_shard_files( user_agent=user_agent, ) if subfolder is not None: - cached_folder = cached_folder + "/" + subfolder + cached_folder = os.path.join(cached_folder, subfolder) # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so # we don't have to catch them here. We have also dealt with EntryNotFoundError. From 6ee794fac0af7ee9fa1c9fd3f33494646907dfef Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 6 Jul 2024 07:39:31 +0530 Subject: [PATCH 4/4] add a small test --- tests/models/unets/test_models_unet_2d_condition.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 63e66dabf0c8..a84968e613b5 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1045,6 +1045,18 @@ def test_load_sharded_checkpoint_from_hub(self): assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) + @require_torch_gpu + def test_load_sharded_checkpoint_from_hub_subfolder(self): + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() + loaded_model = self.model_class.from_pretrained( + "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet" + ) + loaded_model = loaded_model.to(torch_device) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + @require_torch_gpu def test_load_sharded_checkpoint_from_hub_local(self): _, inputs_dict = self.prepare_init_args_and_inputs_for_common()