diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 6f35179e96..4c17b32c29 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -528,7 +528,9 @@ def download( If source is "ngc_private", you need specify the NGC_API_KEY in the environment variable. repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub". If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag". - If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name". + If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name". Please note that + bundles for "monaihosting" source are also hosted on Hugging Face Hub, but the "repo_id" is always in the form + of "MONAI/bundle_name", therefore, this argument is not required for "monaihosting" source. If `source` is "ngc_private", it should be in the form of "org/org_name" or "org/org_name/team/team_name", or you can specify the environment variable NGC_ORG and NGC_TEAM. url: url to download the data. If not `None`, data will be downloaded directly @@ -600,11 +602,15 @@ def download( _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_) elif source_ == "monaihosting": try: + extract_path = os.path.join(bundle_dir_, name_) + huggingface_hub.snapshot_download(repo_id=f"MONAI/{name_}", revision=version_, local_dir=extract_path) + except (huggingface_hub.errors.RevisionNotFoundError, huggingface_hub.errors.RepositoryNotFoundError): + # if bundle or version not found from huggingface, download from ngc monaihosting _download_from_monaihosting( download_path=bundle_dir_, filename=name_, version=version_, progress=progress_ ) except urllib.error.HTTPError: - # for monaihosting bundles, if cannot download from default host, download according to bundle_info + # if also cannot download from ngc monaihosting, download according to bundle_info _download_from_bundle_info( download_path=bundle_dir_, filename=name_, version=version_, progress=progress_ ) diff --git a/tests/bundle/test_bundle_download.py b/tests/bundle/test_bundle_download.py index da58a6313e..dc73ea6d99 100644 --- a/tests/bundle/test_bundle_download.py +++ b/tests/bundle/test_bundle_download.py @@ -63,6 +63,8 @@ TEST_CASE_6 = [["models/model.pt", "configs/train.json"], "renalStructures_CECT_segmentation", "0.1.0"] +TEST_CASE_6_HF = [["models/model.pt", "configs/train.yaml"], "mednist_ddpm", "1.0.1"] + TEST_CASE_7 = [ ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"], "test_bundle", @@ -193,6 +195,7 @@ def test_ngc_private_source_download_bundle(self, bundle_files, bundle_name, _ur @parameterized.expand([TEST_CASE_6]) @skip_if_quick + @skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.") def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, version): with skip_if_downloading_fails(): # download a single file from url, also use `args_file` @@ -239,6 +242,7 @@ def test_list_latest_versions(self): self.assertEqual(_list_latest_versions(data), ["1.1", "1.0"]) @skip_if_quick + @skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.") @patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"}) def test_download_monaihosting(self, mock_get_versions): """Test checking MONAI version from a metadata file.""" @@ -333,6 +337,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) @parameterized.expand([TEST_CASE_8]) @skip_if_quick + @skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.") def test_load_weights_with_net_override(self, bundle_name, device, net_override): with skip_if_downloading_fails(): # download bundle, and load weights from the downloaded path