diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 322a2efc5b..2b1d3cd6f7 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -45,6 +45,7 @@ from monai.utils import ( check_parent_dir, deprecated_arg, + deprecated_arg_default, ensure_tuple, get_equivalent_dtype, min_version, @@ -61,6 +62,7 @@ logger = get_logger(module_name=__name__) # set BUNDLE_DOWNLOAD_SRC="ngc" to use NGC source in default for bundle download +# set BUNDLE_DOWNLOAD_SRC="monaihosting" to use monaihosting source in default for bundle download download_source = os.environ.get("BUNDLE_DOWNLOAD_SRC", "github") PPRINT_CONFIG_N = 5 @@ -80,9 +82,9 @@ def _update_args(args: str | dict | None = None, ignore_none: bool = True, **kwa if isinstance(args, str): # args are defined in a structured file args_ = ConfigParser.load_config_file(args) - # recursively update the default args with new args for k, v in kwargs.items(): + print(k, v) if ignore_none and v is None: continue if isinstance(v, dict) and isinstance(args_.get(k), dict): @@ -156,6 +158,11 @@ def _get_ngc_bundle_url(model_name: str, version: str) -> str: return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{model_name}/versions/{version}/zip" +def _get_monaihosting_bundle_url(model_name: str, version: str) -> str: + monaihosting_root_path = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" + return f"{monaihosting_root_path}/{model_name}/versions/{version}/files/{model_name}_v{version}.zip" + + def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True) -> None: repo_owner, repo_name, tag_name = repo.split("/") if ".zip" not in filename: @@ -166,6 +173,13 @@ def _download_from_github(repo: str, download_path: Path, filename: str, progres extractall(filepath=filepath, output_dir=download_path, has_base=True) +def _download_from_monaihosting(download_path: Path, filename: str, version: str, progress: bool) -> None: + url = 
_get_monaihosting_bundle_url(model_name=filename, version=version) + filepath = download_path / f"{filename}_v{version}.zip" + download_url(url=url, filepath=filepath, hash_val=None, progress=progress) + extractall(filepath=filepath, output_dir=download_path, has_base=True) + + def _add_ngc_prefix(name: str, prefix: str = "monai_") -> str: if name.startswith(prefix): return name @@ -192,6 +206,19 @@ def _download_from_ngc( extractall(filepath=filepath, output_dir=extract_path, has_base=True) +def _get_latest_bundle_version_monaihosting(name): + url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" + full_url = f"{url}/{name}" + requests_get, has_requests = optional_import("requests", name="get") + if has_requests: + resp = requests_get(full_url) + resp.raise_for_status() + else: + raise ValueError("NGC API requires requests package. Please install it.") + model_info = json.loads(resp.text) + return model_info["model"]["latestVersionIdStr"] + + def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, list[str] | str] | Any | None: if source == "ngc": name = _add_ngc_prefix(name) @@ -200,11 +227,15 @@ def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, l if v["name"] == name: return v["latest"] return None + elif source == "monaihosting": + return _get_latest_bundle_version_monaihosting(name) elif source == "github": repo_owner, repo_name, tag_name = repo.split("/") return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"] else: - raise ValueError(f"To get the latest bundle version, source should be 'github' or 'ngc', got {source}.") + raise ValueError( + f"To get the latest bundle version, source should be 'github', 'monaihosting' or 'ngc', got {source}." 
+ ) def _process_bundle_dir(bundle_dir: PathLike | None = None) -> Path: @@ -217,6 +248,7 @@ def _process_bundle_dir(bundle_dir: PathLike | None = None) -> Path: return Path(bundle_dir) +@deprecated_arg_default("source", "github", "monaihosting", since="1.3", replaced="1.4") def download( name: str | None = None, version: str | None = None, @@ -247,6 +279,9 @@ def download( # Execute this module as a CLI entry, and download bundle from ngc with latest version: python -m monai.bundle download --name --source "ngc" --bundle_dir "./" + # Execute this module as a CLI entry, and download bundle from monaihosting with latest version: + python -m monai.bundle download --name --source "monaihosting" --bundle_dir "./" + # Execute this module as a CLI entry, and download bundle via URL: python -m monai.bundle download --name --url @@ -270,7 +305,7 @@ def download( Default is `bundle` subfolder under `torch.hub.get_dir()`. source: storage location name. This argument is used when `url` is `None`. In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and - it should be "ngc" or "github". + it should be "ngc", "monaihosting" or "github". repo: repo name. This argument is used when `url` is `None` and `source` is "github". If used, it should be in the form of "repo_owner/repo_name/release_tag". url: url to download the data. If not `None`, data will be downloaded directly @@ -324,6 +359,8 @@ def download( if version_ is not None: name_ = "_v".join([name_, version_]) _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_) + elif source_ == "monaihosting": + _download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_) elif source_ == "ngc": _download_from_ngc( download_path=bundle_dir_, @@ -334,7 +371,8 @@ def download( ) else: raise NotImplementedError( - f"Currently only download from `url`, source 'github' or 'ngc' are implemented, got source: {source_}." 
+ "Currently only download from `url`, source 'github', 'monaihosting' or 'ngc' are implemented, " + f"got source: {source_}." ) @@ -374,7 +412,7 @@ def load( source: storage location name. This argument is used when `model_file` is not existing locally and need to be downloaded first. In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and - it should be "ngc" or "github". + it should be "ngc", "monaihosting" or "github". repo: repo name. This argument is used when `url` is `None` and `source` is "github". If used, it should be in the form of "repo_owner/repo_name/release_tag". remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index 52aa515111..36e935bf08 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -58,6 +58,12 @@ "model.ts", ] +TEST_CASE_6 = [ + ["models/model.pt", "models/model.ts", "configs/train.json"], + "brats_mri_segmentation", + "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/brats_mri_segmentation/versions/0.3.9/files/brats_mri_segmentation_v0.3.9.zip", +] + class TestDownload(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) @@ -91,7 +97,7 @@ def test_url_download_bundle(self, bundle_files, bundle_name, url, hash_val): parser = ConfigParser() parser.export_config_file(config=def_args, filepath=def_args_file) cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--args_file", def_args_file] - cmd += ["--url", url] + cmd += ["--url", url, "--source", "github"] command_line_tests(cmd) for file in bundle_files: file_path = os.path.join(tempdir, bundle_name, file) @@ -99,6 +105,23 @@ def test_url_download_bundle(self, bundle_files, bundle_name, url, hash_val): if file == "network.json": self.assertTrue(check_hash(filepath=file_path, val=hash_val)) + @parameterized.expand([TEST_CASE_6]) + @skip_if_quick + def 
test_monaihosting_download_bundle(self, bundle_files, bundle_name, url): + with skip_if_downloading_fails(): + # download the bundle zip from the monaihosting url, also use `args_file` + with tempfile.TemporaryDirectory() as tempdir: + def_args = {"name": bundle_name, "bundle_dir": tempdir, "url": ""} + def_args_file = os.path.join(tempdir, "def_args.json") + parser = ConfigParser() + parser.export_config_file(config=def_args, filepath=def_args_file) + cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--args_file", def_args_file] + cmd += ["--url", url, "--progress", "False", "--source", "monaihosting"] + command_line_tests(cmd) + for file in bundle_files: + file_path = os.path.join(tempdir, bundle_name, file) + self.assertTrue(os.path.exists(file_path)) + class TestLoad(unittest.TestCase): @parameterized.expand([TEST_CASE_4]) @@ -113,6 +136,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) model_file=model_file, bundle_dir=tempdir, repo=repo, + source="github", progress=False, device=device, ) @@ -142,6 +166,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) progress=False, device=device, net_name=model_name, + source="github", **net_args, ) model_2.eval() @@ -165,6 +190,7 @@ def test_load_ts_module(self, bundle_files, bundle_name, version, repo, device, repo=repo, progress=False, device=device, + source="github", config_files=("network.json",), )