Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 111 additions & 27 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import torch
from torch.cuda import is_available

from monai.apps.mmars.mmars import _get_all_ngc_models
from monai._version import get_versions
from monai.apps.utils import _basename, download_url, extractall, get_logger
from monai.bundle.config_item import ConfigComponent
from monai.bundle.config_parser import ConfigParser
Expand Down Expand Up @@ -67,6 +67,9 @@
DEFAULT_DOWNLOAD_SOURCE = os.environ.get("BUNDLE_DOWNLOAD_SRC", "monaihosting")
PPRINT_CONFIG_N = 5

MONAI_HOSTING_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
NGC_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit"


def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict:
"""
Expand Down Expand Up @@ -169,16 +172,19 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam


def _get_ngc_bundle_url(model_name: str, version: str) -> str:
return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{model_name.lower()}/versions/{version}/zip"
return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip"


def _get_ngc_private_base_url(repo: str) -> str:
return f"https://api.ngc.nvidia.com/v2/{repo}/models"


def _get_ngc_private_bundle_url(model_name: str, version: str, repo: str) -> str:
return f"https://api.ngc.nvidia.com/v2/{repo}/models/{model_name.lower()}/versions/{version}/zip"
return f"{_get_ngc_private_base_url(repo)}/{model_name.lower()}/versions/{version}/zip"


def _get_monaihosting_bundle_url(model_name: str, version: str) -> str:
monaihosting_root_path = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
return f"{monaihosting_root_path}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip"
return f"{MONAI_HOSTING_BASE_URL}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip"


def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True) -> None:
Expand Down Expand Up @@ -267,8 +273,7 @@ def _get_ngc_token(api_key, retry=0):


def _get_latest_bundle_version_monaihosting(name):
url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
full_url = f"{url}/{name.lower()}"
full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}"
requests_get, has_requests = optional_import("requests", name="get")
if has_requests:
resp = requests_get(full_url)
Expand All @@ -279,36 +284,114 @@ def _get_latest_bundle_version_monaihosting(name):
return model_info["model"]["latestVersionIdStr"]


def _get_latest_bundle_version_private_registry(name, repo, headers=None):
url = f"https://api.ngc.nvidia.com/v2/{repo}/models"
full_url = f"{url}/{name.lower()}"
requests_get, has_requests = optional_import("requests", name="get")
if has_requests:
headers = {} if headers is None else headers
resp = requests_get(full_url, headers=headers)
resp.raise_for_status()
else:
raise ValueError("NGC API requires requests package. Please install it.")
def _examine_monai_version(monai_version: str) -> tuple[bool, str]:
"""Examine if the package version is compatible with the MONAI version in the metadata."""
version_dict = get_versions()
package_version = version_dict.get("version", "0+unknown")
if package_version == "0+unknown":
return False, "Package version is not available. Skipping version check."
if monai_version == "0+unknown":
return False, "MONAI version is not specified in the bundle. Skipping version check."
# treat rc versions as the same as the release version
package_version = re.sub(r"rc\d.*", "", package_version)
monai_version = re.sub(r"rc\d.*", "", monai_version)
if package_version < monai_version:
return (
False,
f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}.",
)
return True, ""


def _check_monai_version(bundle_dir: PathLike, name: str) -> None:
"""Get the `monai_version` from the metadata.json and compare if it is smaller than the installed `monai` package version"""
metadata_file = Path(bundle_dir) / name / "configs" / "metadata.json"
if not metadata_file.exists():
logger.warning(f"metadata file not found in {metadata_file}.")
return
with open(metadata_file) as f:
metadata = json.load(f)
is_compatible, msg = _examine_monai_version(metadata.get("monai_version", "0+unknown"))
if not is_compatible:
logger.warning(msg)


def _list_latest_versions(data: dict, max_versions: int = 3) -> list[str]:
"""
Extract the latest versions from the data dictionary.

Args:
data: the data dictionary.
max_versions: the maximum number of versions to return.

Returns:
versions of the latest models in the reverse order of creation date, e.g. ['1.0.0', '0.9.0', '0.8.0'].
"""
# Check if the data is a dictionary and it has the key 'modelVersions'
if not isinstance(data, dict) or "modelVersions" not in data:
raise ValueError("The data is not a dictionary or it does not have the key 'modelVersions'.")

# Extract the list of model versions
model_versions = data["modelVersions"]

if (
not isinstance(model_versions, list)
or len(model_versions) == 0
or "createdDate" not in model_versions[0]
or "versionId" not in model_versions[0]
):
raise ValueError(
"The model versions are not a list or it is empty or it does not have the keys 'createdDate' and 'versionId'."
)

# Sort the versions by the 'createdDate' in descending order
sorted_versions = sorted(model_versions, key=lambda x: x["createdDate"], reverse=True)
return [v["versionId"] for v in sorted_versions[:max_versions]]


def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: dict | None = None) -> str:
base_url = _get_ngc_private_base_url(repo) if repo else NGC_BASE_URL
version_endpoint = base_url + f"/{name.lower()}/versions/"

if not has_requests:
raise ValueError("requests package is required, please install it.")

version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements
if headers:
version_header.update(headers)
resp = requests_get(version_endpoint, headers=version_header)
resp.raise_for_status()
model_info = json.loads(resp.text)
return model_info["model"]["latestVersionIdStr"]
latest_versions = _list_latest_versions(model_info)

for version in latest_versions:
file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json"
resp = requests_get(file_endpoint, headers=headers)
metadata = json.loads(resp.text)
resp.raise_for_status()
# if the package version is not available or the model is compatible with the package version
is_compatible, _ = _examine_monai_version(metadata["monai_version"])
if is_compatible:
if version != latest_versions[0]:
logger.info(f"Latest version is {latest_versions[0]}, but the compatible version is {version}.")
return version

# if no compatible version is found, return the latest version
return latest_versions[0]


def _get_latest_bundle_version(
source: str, name: str, repo: str, **kwargs: Any
) -> dict[str, list[str] | str] | Any | None:
if source == "ngc":
name = _add_ngc_prefix(name)
model_dict = _get_all_ngc_models(name)
for v in model_dict.values():
if v["name"] == name:
return v["latest"]
return None
return _get_latest_bundle_version_ngc(name)
elif source == "monaihosting":
return _get_latest_bundle_version_monaihosting(name)
elif source == "ngc_private":
headers = kwargs.pop("headers", {})
name = _add_ngc_prefix(name)
return _get_latest_bundle_version_private_registry(name, repo, headers)
return _get_latest_bundle_version_ngc(name, repo=repo, headers=headers)
elif source == "github":
repo_owner, repo_name, tag_name = repo.split("/")
return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"]
Expand Down Expand Up @@ -470,9 +553,8 @@ def download(
if version_ is None:
version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_, headers=headers)
if source_ == "github":
if version_ is not None:
name_ = "_v".join([name_, version_])
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_)
name_ver = "_v".join([name_, version_]) if version_ is not None else name_
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_)
elif source_ == "monaihosting":
_download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_)
elif source_ == "ngc":
Expand Down Expand Up @@ -501,6 +583,8 @@ def download(
f"got source: {source_}."
)

_check_monai_version(bundle_dir_, name_)


@deprecated_arg("net_name", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
@deprecated_arg("net_kwargs", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
Expand Down
51 changes: 51 additions & 0 deletions tests/test_bundle_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import tempfile
import unittest
from unittest.case import skipUnless
from unittest.mock import patch

import numpy as np
import torch
Expand All @@ -24,6 +25,7 @@
import monai.networks.nets as nets
from monai.apps import check_hash
from monai.bundle import ConfigParser, create_workflow, load
from monai.bundle.scripts import _examine_monai_version, _list_latest_versions, download
from monai.utils import optional_import
from tests.utils import (
SkipIfBeforePyTorchVersion,
Expand Down Expand Up @@ -207,6 +209,55 @@ def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, ve
file_path = os.path.join(tempdir, bundle_name, file)
self.assertTrue(os.path.exists(file_path))

@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"})
def test_examine_monai_version(self, mock_get_versions):
self.assertTrue(_examine_monai_version("1.1")[0]) # Should return True, compatible
self.assertTrue(_examine_monai_version("1.2rc1")[0]) # Should return True, compatible
self.assertFalse(_examine_monai_version("1.3")[0]) # Should return False, not compatible

@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2rc1"})
def test_examine_monai_version_rc(self, mock_get_versions):
self.assertTrue(_examine_monai_version("1.2")[0]) # Should return True, compatible
self.assertFalse(_examine_monai_version("1.3")[0]) # Should return False, not compatible

def test_list_latest_versions(self):
"""Test listing of the latest versions."""
data = {
"modelVersions": [
{"createdDate": "2021-01-01", "versionId": "1.0"},
{"createdDate": "2021-01-02", "versionId": "1.1"},
{"createdDate": "2021-01-03", "versionId": "1.2"},
]
}
self.assertEqual(_list_latest_versions(data), ["1.2", "1.1", "1.0"])
self.assertEqual(_list_latest_versions(data, max_versions=2), ["1.2", "1.1"])
data = {
"modelVersions": [
{"createdDate": "2021-01-01", "versionId": "1.0"},
{"createdDate": "2021-01-02", "versionId": "1.1"},
]
}
self.assertEqual(_list_latest_versions(data), ["1.1", "1.0"])

@skip_if_quick
@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"})
def test_download_monaihosting(self, mock_get_versions):
"""Test checking MONAI version from a metadata file."""
with patch("monai.bundle.scripts.logger") as mock_logger:
with tempfile.TemporaryDirectory() as tempdir:
download(name="spleen_ct_segmentation", bundle_dir=tempdir, source="monaihosting")
# Should have a warning message because the latest version is using monai > 1.2
mock_logger.warning.assert_called_once()

@skip_if_quick
@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"})
def test_download_ngc(self, mock_get_versions):
"""Test checking MONAI version from a metadata file."""
with patch("monai.bundle.scripts.logger") as mock_logger:
with tempfile.TemporaryDirectory() as tempdir:
download(name="spleen_ct_segmentation", bundle_dir=tempdir, source="ngc")
mock_logger.warning.assert_not_called()


@skip_if_no_cuda
class TestLoad(unittest.TestCase):
Expand Down