From a188d3a053372241300127c9877e74cc2f6afabc Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 26 Mar 2025 12:08:46 +0800 Subject: [PATCH 1/3] update bundle download api Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 6 +++++- monai/utils/jupyter_utils.py | 2 +- tests/bundle/test_bundle_download.py | 5 +++++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 6f35179e96..9ef77cb070 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -600,11 +600,15 @@ def download( _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_) elif source_ == "monaihosting": try: + extract_path = os.path.join(bundle_dir_, name_) + huggingface_hub.snapshot_download(repo_id=f"MONAI/{name_}", revision=version_, local_dir=extract_path) + except (huggingface_hub.errors.RevisionNotFoundError, huggingface_hub.errors.RepositoryNotFoundError): + # if bundle or version not found from huggingface, download from ngc monaihosting _download_from_monaihosting( download_path=bundle_dir_, filename=name_, version=version_, progress=progress_ ) except urllib.error.HTTPError: - # for monaihosting bundles, if cannot download from default host, download according to bundle_info + # if also cannot download from ngc monaihosting, download according to bundle_info _download_from_bundle_info( download_path=bundle_dir_, filename=name_, version=version_, progress=progress_ ) diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index c93e93dcb9..b1b43a6767 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -234,7 +234,7 @@ def plot_engine_status( def _get_loss_from_output( - output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor, + output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor ) -> torch.Tensor: """Returns a single value from the network output, which is a dict or tensor.""" diff --git a/tests/bundle/test_bundle_download.py b/tests/bundle/test_bundle_download.py index da58a6313e..dc73ea6d99 100644 --- a/tests/bundle/test_bundle_download.py +++ b/tests/bundle/test_bundle_download.py @@ -63,6 +63,8 @@ TEST_CASE_6 = [["models/model.pt", "configs/train.json"], "renalStructures_CECT_segmentation", "0.1.0"] +TEST_CASE_6_HF = [["models/model.pt", "configs/train.yaml"], "mednist_ddpm", "1.0.1"] + TEST_CASE_7 = [ ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"], "test_bundle", @@ -193,6 +195,7 @@ def test_ngc_private_source_download_bundle(self, bundle_files, bundle_name, _ur @parameterized.expand([TEST_CASE_6]) @skip_if_quick + @skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.") def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, version): with skip_if_downloading_fails(): # download a single file from url, also use `args_file` @@ -239,6 +242,7 @@ def test_list_latest_versions(self): self.assertEqual(_list_latest_versions(data), ["1.1", "1.0"]) @skip_if_quick + @skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.") @patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"}) def test_download_monaihosting(self, mock_get_versions): """Test checking MONAI version from a metadata file.""" @@ -333,6 +337,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) @parameterized.expand([TEST_CASE_8]) @skip_if_quick + @skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.") def test_load_weights_with_net_override(self, bundle_name, device, net_override): with skip_if_downloading_fails(): # download bundle, and load weights from the downloaded path From 9a6a1b6cfdbc747061af0abc50b097986e876e81 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 26 Mar 2025 12:23:17 +0800 Subject: [PATCH 2/3] revert autofix change Signed-off-by: Yiheng Wang --- monai/utils/jupyter_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index b1b43a6767..c93e93dcb9 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -234,7 +234,7 @@ def plot_engine_status( def _get_loss_from_output( - output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor + output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor, ) -> torch.Tensor: """Returns a single value from the network output, which is a dict or tensor.""" From ab336b10a9a3593dc50405e1a694808fd4c93d0c Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 8 Apr 2025 11:26:56 +0800 Subject: [PATCH 3/3] add doc Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 9ef77cb070..4c17b32c29 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -528,7 +528,9 @@ def download( If source is "ngc_private", you need specify the NGC_API_KEY in the environment variable. repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub". If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag". - If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name". + If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name". Please note that + bundles for "monaihosting" source are also hosted on Hugging Face Hub, but the "repo_id" is always in the form + of "MONAI/bundle_name", therefore, this argument is not required for "monaihosting" source. If `source` is "ngc_private", it should be in the form of "org/org_name" or "org/org_name/team/team_name", or you can specify the environment variable NGC_ORG and NGC_TEAM. url: url to download the data. If not `None`, data will be downloaded directly