diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 56146546e8..4967b6cf50 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -27,7 +27,7 @@ import torch from torch.cuda import is_available -from monai.apps.mmars.mmars import _get_all_ngc_models +from monai._version import get_versions from monai.apps.utils import _basename, download_url, extractall, get_logger from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser @@ -67,6 +67,9 @@ DEFAULT_DOWNLOAD_SOURCE = os.environ.get("BUNDLE_DOWNLOAD_SRC", "monaihosting") PPRINT_CONFIG_N = 5 +MONAI_HOSTING_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" +NGC_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit" + def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict: """ @@ -169,16 +172,19 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam def _get_ngc_bundle_url(model_name: str, version: str) -> str: - return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{model_name.lower()}/versions/{version}/zip" + return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip" + + +def _get_ngc_private_base_url(repo: str) -> str: + return f"https://api.ngc.nvidia.com/v2/{repo}/models" def _get_ngc_private_bundle_url(model_name: str, version: str, repo: str) -> str: - return f"https://api.ngc.nvidia.com/v2/{repo}/models/{model_name.lower()}/versions/{version}/zip" + return f"{_get_ngc_private_base_url(repo)}/{model_name.lower()}/versions/{version}/zip" def _get_monaihosting_bundle_url(model_name: str, version: str) -> str: - monaihosting_root_path = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" - return f"{monaihosting_root_path}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip" + return f"{MONAI_HOSTING_BASE_URL}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip" def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True) -> None: @@ -267,8 +273,7 @@ def _get_ngc_token(api_key, retry=0): def _get_latest_bundle_version_monaihosting(name): - url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" - full_url = f"{url}/{name.lower()}" + full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}" requests_get, has_requests = optional_import("requests", name="get") if has_requests: resp = requests_get(full_url) @@ -279,18 +284,100 @@ def _get_latest_bundle_version_monaihosting(name): return model_info["model"]["latestVersionIdStr"] -def _get_latest_bundle_version_private_registry(name, repo, headers=None): - url = f"https://api.ngc.nvidia.com/v2/{repo}/models" - full_url = f"{url}/{name.lower()}" - requests_get, has_requests = optional_import("requests", name="get") - if has_requests: - headers = {} if headers is None else headers - resp = requests_get(full_url, headers=headers) - resp.raise_for_status() - else: - raise ValueError("NGC API requires requests package. Please install it.") +def _examine_monai_version(monai_version: str) -> tuple[bool, str]: + """Examine if the package version is compatible with the MONAI version in the metadata.""" + version_dict = get_versions() + package_version = version_dict.get("version", "0+unknown") + if package_version == "0+unknown": + return False, "Package version is not available. Skipping version check." + if monai_version == "0+unknown": + return False, "MONAI version is not specified in the bundle. Skipping version check." + # treat rc versions as the same as the release version + package_version = re.sub(r"rc\d.*", "", package_version) + monai_version = re.sub(r"rc\d.*", "", monai_version) + if package_version < monai_version: + return ( + False, + f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}.", + ) + return True, "" + + +def _check_monai_version(bundle_dir: PathLike, name: str) -> None: + """Get the `monai_version` from the metadata.json and compare if it is smaller than the installed `monai` package version""" + metadata_file = Path(bundle_dir) / name / "configs" / "metadata.json" + if not metadata_file.exists(): + logger.warning(f"metadata file not found in {metadata_file}.") + return + with open(metadata_file) as f: + metadata = json.load(f) + is_compatible, msg = _examine_monai_version(metadata.get("monai_version", "0+unknown")) + if not is_compatible: + logger.warning(msg) + + +def _list_latest_versions(data: dict, max_versions: int = 3) -> list[str]: + """ + Extract the latest versions from the data dictionary. + + Args: + data: the data dictionary. + max_versions: the maximum number of versions to return. + + Returns: + versions of the latest models in the reverse order of creation date, e.g. ['1.0.0', '0.9.0', '0.8.0']. + """ + # Check if the data is a dictionary and it has the key 'modelVersions' + if not isinstance(data, dict) or "modelVersions" not in data: + raise ValueError("The data is not a dictionary or it does not have the key 'modelVersions'.") + + # Extract the list of model versions + model_versions = data["modelVersions"] + + if ( + not isinstance(model_versions, list) + or len(model_versions) == 0 + or "createdDate" not in model_versions[0] + or "versionId" not in model_versions[0] + ): + raise ValueError( + "The model versions are not a list or it is empty or it does not have the keys 'createdDate' and 'versionId'." + ) + + # Sort the versions by the 'createdDate' in descending order + sorted_versions = sorted(model_versions, key=lambda x: x["createdDate"], reverse=True) + return [v["versionId"] for v in sorted_versions[:max_versions]] + + +def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: dict | None = None) -> str: + base_url = _get_ngc_private_base_url(repo) if repo else NGC_BASE_URL + version_endpoint = base_url + f"/{name.lower()}/versions/" + + if not has_requests: + raise ValueError("requests package is required, please install it.") + + version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements + if headers: + version_header.update(headers) + resp = requests_get(version_endpoint, headers=version_header) + resp.raise_for_status() model_info = json.loads(resp.text) - return model_info["model"]["latestVersionIdStr"] + latest_versions = _list_latest_versions(model_info) + + for version in latest_versions: + file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json" + resp = requests_get(file_endpoint, headers=headers) + metadata = json.loads(resp.text) + resp.raise_for_status() + # if the package version is not available or the model is compatible with the package version + is_compatible, _ = _examine_monai_version(metadata["monai_version"]) + if is_compatible: + if version != latest_versions[0]: + logger.info(f"Latest version is {latest_versions[0]}, but the compatible version is {version}.") + return version + + # if no compatible version is found, return the latest version + return latest_versions[0] def _get_latest_bundle_version( @@ -298,17 +385,13 @@ def _get_latest_bundle_version( ) -> dict[str, list[str] | str] | Any | None: if source == "ngc": name = _add_ngc_prefix(name) - model_dict = _get_all_ngc_models(name) - for v in model_dict.values(): - if v["name"] == name: - return v["latest"] - return None + return _get_latest_bundle_version_ngc(name) elif source == "monaihosting": return _get_latest_bundle_version_monaihosting(name) elif source == "ngc_private": headers = kwargs.pop("headers", {}) name = _add_ngc_prefix(name) - return _get_latest_bundle_version_private_registry(name, repo, headers) + return _get_latest_bundle_version_ngc(name, repo=repo, headers=headers) elif source == "github": repo_owner, repo_name, tag_name = repo.split("/") return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"] @@ -470,9 +553,8 @@ def download( if version_ is None: version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_, headers=headers) if source_ == "github": - if version_ is not None: - name_ = "_v".join([name_, version_]) - _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_) + name_ver = "_v".join([name_, version_]) if version_ is not None else name_ + _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_) elif source_ == "monaihosting": _download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_) elif source_ == "ngc": @@ -501,6 +583,8 @@ def download( f"got source: {source_}." ) + _check_monai_version(bundle_dir_, name_) + @deprecated_arg("net_name", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.") @deprecated_arg("net_kwargs", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.") diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index fe7caf5c17..331d228f1e 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -16,6 +16,7 @@ import tempfile import unittest from unittest.case import skipUnless +from unittest.mock import patch import numpy as np import torch @@ -24,6 +25,7 @@ import monai.networks.nets as nets from monai.apps import check_hash from monai.bundle import ConfigParser, create_workflow, load +from monai.bundle.scripts import _examine_monai_version, _list_latest_versions, download from monai.utils import optional_import from tests.utils import ( SkipIfBeforePyTorchVersion, @@ -207,6 +209,55 @@ def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, ve file_path = os.path.join(tempdir, bundle_name, file) self.assertTrue(os.path.exists(file_path)) + @patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"}) + def test_examine_monai_version(self, mock_get_versions): + self.assertTrue(_examine_monai_version("1.1")[0]) # Should return True, compatible + self.assertTrue(_examine_monai_version("1.2rc1")[0]) # Should return True, compatible + self.assertFalse(_examine_monai_version("1.3")[0]) # Should return False, not compatible + + @patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2rc1"}) + def test_examine_monai_version_rc(self, mock_get_versions): + self.assertTrue(_examine_monai_version("1.2")[0]) # Should return True, compatible + self.assertFalse(_examine_monai_version("1.3")[0]) # Should return False, not compatible + + def test_list_latest_versions(self): + """Test listing of the latest versions.""" + data = { + "modelVersions": [ + {"createdDate": "2021-01-01", "versionId": "1.0"}, + {"createdDate": "2021-01-02", "versionId": "1.1"}, + {"createdDate": "2021-01-03", "versionId": "1.2"}, + ] + } + self.assertEqual(_list_latest_versions(data), ["1.2", "1.1", "1.0"]) + self.assertEqual(_list_latest_versions(data, max_versions=2), ["1.2", "1.1"]) + data = { + "modelVersions": [ + {"createdDate": "2021-01-01", "versionId": "1.0"}, + {"createdDate": "2021-01-02", "versionId": "1.1"}, + ] + } + self.assertEqual(_list_latest_versions(data), ["1.1", "1.0"]) + + @skip_if_quick + @patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"}) + def test_download_monaihosting(self, mock_get_versions): + """Test checking MONAI version from a metadata file.""" + with patch("monai.bundle.scripts.logger") as mock_logger: + with tempfile.TemporaryDirectory() as tempdir: + download(name="spleen_ct_segmentation", bundle_dir=tempdir, source="monaihosting") + # Should have a warning message because the latest version is using monai > 1.2 + mock_logger.warning.assert_called_once() + + @skip_if_quick + @patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"}) + def test_download_ngc(self, mock_get_versions): + """Test checking MONAI version from a metadata file.""" + with patch("monai.bundle.scripts.logger") as mock_logger: + with tempfile.TemporaryDirectory() as tempdir: + download(name="spleen_ct_segmentation", bundle_dir=tempdir, source="ngc") + mock_logger.warning.assert_not_called() + @skip_if_no_cuda class TestLoad(unittest.TestCase):