Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
integration-py3:
container:
image: nvcr.io/nvidia/pytorch:21.12-py3 # CUDA 11.5
options: --gpus all
options: --gpus all # shm-size 4g works fine
runs-on: [self-hosted, linux, x64, common]
steps:
# checkout the pull request branch
Expand Down
2 changes: 2 additions & 0 deletions docs/source/bundle.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ Model Bundle
`Scripts`
---------
.. autofunction:: ckpt_export
.. autofunction:: download
.. autofunction:: load
.. autofunction:: run
.. autofunction:: verify_metadata
.. autofunction:: verify_net_in_out
14 changes: 14 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,13 @@ Spatial
:members:
:special-members: __call__

`GridSplit`
"""""""""""
.. autoclass:: GridSplit
:members:
:special-members: __call__


Smooth Field
^^^^^^^^^^^^

Expand Down Expand Up @@ -1506,6 +1513,13 @@ Spatial (Dict)
:members:
:special-members: __call__

`GridSplitd`
""""""""""""
.. autoclass:: GridSplitd
:members:
:special-members: __call__


`RandRotate90d`
"""""""""""""""
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRotate90d.png
Expand Down
2 changes: 1 addition & 1 deletion monai/bundle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable
from .config_parser import ConfigParser
from .reference_resolver import ReferenceResolver
from .scripts import ckpt_export, run, verify_metadata, verify_net_in_out
from .scripts import ckpt_export, download, load, run, verify_metadata, verify_net_in_out
from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY
2 changes: 1 addition & 1 deletion monai/bundle/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.


from monai.bundle.scripts import ckpt_export, run, verify_metadata, verify_net_in_out
from monai.bundle.scripts import ckpt_export, download, run, verify_metadata, verify_net_in_out

if __name__ == "__main__":
from monai.utils import optional_import
Expand Down
205 changes: 199 additions & 6 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,29 @@

import ast
import json
import os
import pprint
import re
from logging.config import fileConfig
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple, Union

import torch
from torch.cuda import is_available

from monai.apps.utils import download_url, get_logger
from monai.apps.utils import _basename, download_url, extractall, get_logger
from monai.bundle.config_item import ConfigComponent
from monai.bundle.config_parser import ConfigParser
from monai.config import IgniteInfo, PathLike
from monai.data import save_net_with_metadata
from monai.data import load_net_with_metadata, save_net_with_metadata
from monai.networks import convert_to_torchscript, copy_model_state
from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import
from monai.utils.misc import ensure_tuple

validate, _ = optional_import("jsonschema", name="validate")
ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
requests_get, has_requests = optional_import("requests", name="get")

logger = get_logger(module_name=__name__)

Expand Down Expand Up @@ -116,6 +121,182 @@ def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int
return tuple(ret)


def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filename: str):
return f"https://github.com/{repo_owner}/{repo_name}/releases/download/{tag_name}/{filename}"


def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True):
if len(repo.split("/")) != 3:
raise ValueError("if source is `github`, repo should be in the form of `repo_owner/repo_name/release_tag`.")
repo_owner, repo_name, tag_name = repo.split("/")
if ".zip" not in filename:
filename += ".zip"
url = _get_git_release_url(repo_owner, repo_name, tag_name=tag_name, filename=filename)
filepath = download_path / f"{filename}"
download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
extractall(filepath=filepath, output_dir=download_path, has_base=True)


def _process_bundle_dir(bundle_dir: Optional[PathLike] = None):
if bundle_dir is None:
get_dir, has_home = optional_import("torch.hub", name="get_dir")
if has_home:
bundle_dir = Path(get_dir()) / "bundle"
else:
raise ValueError("bundle_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?")
return Path(bundle_dir)


def download(
name: Optional[str] = None,
bundle_dir: Optional[PathLike] = None,
source: str = "github",
repo: Optional[str] = None,
url: Optional[str] = None,
progress: bool = True,
args_file: Optional[str] = None,
):
"""
download bundle from the specified source or url. The bundle should be a zip file and it
will be extracted after downloading.
This function refers to:
https://pytorch.org/docs/stable/_modules/torch/hub.html

Typical usage examples:

.. code-block:: bash

# Execute this module as a CLI entry, and download bundle:
python -m monai.bundle download --name "bundle_name" --source "github" --repo "repo_owner/repo_name/release_tag"

# Execute this module as a CLI entry, and download bundle via URL:
python -m monai.bundle download --name "bundle_name" --url <url>

# Set default args of `run` in a JSON / YAML file, help to record and simplify the command line.
# Other args still can override the default args at runtime.
# The content of the JSON / YAML file is a dictionary. For example:
# {"name": "spleen", "bundle_dir": "download", "source": ""}
# then do the following command for downloading:
python -m monai.bundle download --args_file "args.json" --source "github"

Args:
name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
bundle_dir: target directory to store the downloaded data.
Default is `bundle` subfolder under`torch.hub get_dir()`.
source: place that saved the bundle.
If `source` is `github`, the bundle should be within the releases.
repo: repo name. If `None` and `url` is `None`, it must be provided in `args_file`.
If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`.
For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`.
url: url to download the data. If not `None`, data will be downloaded directly
and `source` will not be checked.
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
progress: whether to display a progress bar.
args_file: a JSON or YAML file to provide default values for all the args in this function.
so that the command line inputs can be simplified.

"""
_args = _update_args(
args=args_file, name=name, bundle_dir=bundle_dir, source=source, repo=repo, url=url, progress=progress
)

_log_input_summary(tag="download", args=_args)
name_, bundle_dir_, source_, repo_, url_, progress_ = _pop_args(
_args, name=None, bundle_dir=None, source="github", repo=None, url=None, progress=True
)

bundle_dir_ = _process_bundle_dir(bundle_dir_)

if url_ is not None:
if name is not None:
filepath = bundle_dir_ / f"{name}.zip"
else:
filepath = bundle_dir_ / f"{_basename(url_)}"
download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_)
extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True)
elif source_ == "github":
if name_ is None or repo_ is None:
raise ValueError(
f"To download from source: Github, `name` and `repo` must be provided, got {name_} and {repo_}."
)
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_)
else:
raise NotImplementedError(
f"Currently only download from provided URL in `url` or Github is implemented, got source: {source_}."
)


def load(
name: str,
model_file: Optional[str] = None,
load_ts_module: bool = False,
bundle_dir: Optional[PathLike] = None,
source: str = "github",
repo: Optional[str] = None,
progress: bool = True,
device: Optional[str] = None,
config_files: Sequence[str] = (),
net_name: Optional[str] = None,
**net_kwargs,
):
"""
Load model weights or TorchScript module of a bundle.

Args:
name: bundle name.
model_file: the relative path of the model weights or TorchScript module within bundle.
If `None`, "models/model.pt" or "models/model.ts" will be used.
load_ts_module: a flag to specify if loading the TorchScript module.
bundle_dir: the directory the weights/TorchScript module will be loaded from.
Default is `bundle` subfolder under`torch.hub get_dir()`.
source: the place that saved the bundle.
If `source` is `github`, the bundle should be within the releases.
repo: the repo name. If the weights file does not exist locally and `url` is `None`, it must be provided.
If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`.
For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`.
progress: whether to display a progress bar when downloading.
device: target device of returned weights or module, if `None`, prefer to "cuda" if existing.
config_files: extra filenames would be loaded. The argument only works when loading a TorchScript module,
see `_extra_files` in `torch.jit.load` for more details.
net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights.
This argument only works when loading weights.
net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`.

Returns:
1. If `load_ts_module` is `False` and `net_name` is `None`, return model weights.
2. If `load_ts_module` is `False` and `net_name` is not `None`,
return an instantiated network that loaded the weights.
3. If `load_ts_module` is `True`, return a triple that include a TorchScript module,
the corresponding metadata dict, and extra files dict.
please check `monai.data.load_net_with_metadata` for more details.

"""
bundle_dir_ = _process_bundle_dir(bundle_dir)

if model_file is None:
model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt")
full_path = os.path.join(bundle_dir_, name, model_file)
if not os.path.exists(full_path):
download(name=name, bundle_dir=bundle_dir_, source=source, repo=repo, progress=progress)

if device is None:
device = "cuda:0" if is_available() else "cpu"
# loading with `torch.jit.load`
if load_ts_module is True:
return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
# loading with `torch.load`
model_dict = torch.load(full_path, map_location=torch.device(device))

if net_name is None:
return model_dict
net_kwargs["_target_"] = net_name
configer = ConfigComponent(config=net_kwargs)
model = configer.instantiate()
model.to(device) # type: ignore
model.load_state_dict(model_dict) # type: ignore
return model


def run(
runner_id: Optional[str] = None,
meta_file: Optional[Union[str, Sequence[str]]] = None,
Expand Down Expand Up @@ -249,7 +430,7 @@ def verify_metadata(
try:
# the rest key-values in the _args are for `validate` API
validate(instance=metadata, schema=schema, **_args)
except ValidationError as e:
except ValidationError as e: # pylint: disable=E0712
# as the error message is very long, only extract the key information
logger.info(re.compile(r".*Failed validating", re.S).findall(str(e))[0] + f" against schema `{url}`.")
return
Expand Down Expand Up @@ -370,8 +551,10 @@ def ckpt_export(
filepath: filepath to export, if filename has no extension it becomes `.ts`.
ckpt_file: filepath of the model checkpoint to load.
meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged.
config_file: filepath of the config file, if `None`, must be provided in `args_file`.
if it is a list of file paths, the content of them will be merged.
config_file: filepath of the config file to save in TorchScript model and extract network information,
the saved key in the TorchScript model is the config filename without extension, and the saved config
value is always serialized in JSON format no matter the original file format is JSON or YAML.
it can be a single file or a list of files. if `None`, must be provided in `args_file`.
key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model
weights. if not nested checkpoint, no need to set.
args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`,
Expand Down Expand Up @@ -415,12 +598,22 @@ def ckpt_export(
# convert to TorchScript model and save with meta data, config content
net = convert_to_torchscript(model=net)

extra_files: Dict = {}
for i in ensure_tuple(config_file_):
# split the filename and directory
filename = os.path.basename(i)
# remove extension
filename, _ = os.path.splitext(filename)
if filename in extra_files:
raise ValueError(f"filename '{filename}' is given multiple times in config file list.")
extra_files[filename] = json.dumps(ConfigParser.load_config_file(i)).encode()

save_net_with_metadata(
jit_obj=net,
filename_prefix_or_stream=filepath_,
include_config_vals=False,
append_timestamp=False,
meta_values=parser.get().pop("_meta_", None),
more_extra_files={"config": json.dumps(parser.get()).encode()},
more_extra_files=extra_files,
)
logger.info(f"exported to TorchScript file: {filepath_}.")
27 changes: 10 additions & 17 deletions monai/data/wsi_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,16 @@ def get_data(
f"The image dimension should be 3 but has {patch.ndim}. "
"`WSIReader` is designed to work only with 2D images with color channel."
)

# Check if there are four color channels for RGBA
if mode == "RGBA" and patch.shape[0] != 4:
raise ValueError(
f"The image is expected to have four color channels in '{mode}' mode but has {patch.shape[0]}."
)
# Check if there are three color channels for RGB
elif mode in "RGB" and patch.shape[0] != 3:
raise ValueError(
f"The image is expected to have three color channels in '{mode}' mode but has {patch.shape[0]}. "
)
# Create a list of patches
patch_list.append(patch)

Expand Down Expand Up @@ -408,11 +417,6 @@ def get_patch(
patch = AsChannelFirst()(patch) # type: ignore

# Check if the color channel is 3 (RGB) or 4 (RGBA)
if mode == "RGBA" and patch.shape[0] != 4:
raise ValueError(
f"The image is expected to have four color channels in '{mode}' mode but has {patch.shape[0]}."
)

if mode in "RGB":
if patch.shape[0] not in [3, 4]:
raise ValueError(
Expand Down Expand Up @@ -537,15 +541,4 @@ def get_patch(
# Make it channel first
patch = AsChannelFirst()(patch) # type: ignore

# Check if the color channel is 3 (RGB) or 4 (RGBA)
if mode == "RGBA" and patch.shape[0] != 4:
raise ValueError(
f"The image is expected to have four color channels in '{mode}' mode but has {patch.shape[0]}."
)

elif mode in "RGB" and patch.shape[0] != 3:
raise ValueError(
f"The image is expected to have three color channels in '{mode}' mode but has {patch.shape[0]}. "
)

return patch
Loading