diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index c38b66eda4..767b58a792 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -8,7 +8,7 @@ jobs: integration-py3: container: image: nvcr.io/nvidia/pytorch:21.12-py3 # CUDA 11.5 - options: --gpus all + options: --gpus all # shm-size 4g works fine runs-on: [self-hosted, linux, x64, common] steps: # checkout the pull request branch diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst index 297409cd7e..a28db04091 100644 --- a/docs/source/bundle.rst +++ b/docs/source/bundle.rst @@ -36,6 +36,8 @@ Model Bundle `Scripts` --------- .. autofunction:: ckpt_export +.. autofunction:: download +.. autofunction:: load .. autofunction:: run .. autofunction:: verify_metadata .. autofunction:: verify_net_in_out diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 676e0274fe..a93c48984c 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -737,6 +737,13 @@ Spatial :members: :special-members: __call__ +`GridSplit` +""""""""""" +.. autoclass:: GridSplit + :members: + :special-members: __call__ + + Smooth Field ^^^^^^^^^^^^ @@ -1506,6 +1513,13 @@ Spatial (Dict) :members: :special-members: __call__ +`GridSplitd` +"""""""""""" +.. autoclass:: GridSplitd + :members: + :special-members: __call__ + + `RandRotate90d` """"""""""""""" .. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRotate90d.png diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index d6a452b5a4..f30ee9c40c 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -12,5 +12,5 @@ from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable from .config_parser import ConfigParser from .reference_resolver import ReferenceResolver -from .scripts import ckpt_export, run, verify_metadata, verify_net_in_out +from .scripts import ckpt_export, download, load, run, verify_metadata, verify_net_in_out from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py index d77b396e79..3e3534ef74 100644 --- a/monai/bundle/__main__.py +++ b/monai/bundle/__main__.py @@ -10,7 +10,7 @@ # limitations under the License. -from monai.bundle.scripts import ckpt_export, run, verify_metadata, verify_net_in_out +from monai.bundle.scripts import ckpt_export, download, run, verify_metadata, verify_net_in_out if __name__ == "__main__": from monai.utils import optional_import diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index b741e40e8d..00496dba66 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -11,24 +11,29 @@ import ast import json +import os import pprint import re from logging.config import fileConfig +from pathlib import Path from typing import Dict, Optional, Sequence, Tuple, Union import torch from torch.cuda import is_available -from monai.apps.utils import download_url, get_logger +from monai.apps.utils import _basename, download_url, extractall, get_logger +from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser from monai.config import IgniteInfo, PathLike -from monai.data import save_net_with_metadata +from monai.data import load_net_with_metadata, save_net_with_metadata from monai.networks import convert_to_torchscript, copy_model_state from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import +from monai.utils.misc import ensure_tuple validate, _ = optional_import("jsonschema", name="validate") ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint") +requests_get, has_requests = optional_import("requests", name="get") logger = get_logger(module_name=__name__) @@ -116,6 +121,182 @@ def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int return tuple(ret) +def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filename: str): + return f"https://github.com/{repo_owner}/{repo_name}/releases/download/{tag_name}/{filename}" + + +def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True): + if len(repo.split("/")) != 3: + raise ValueError("if source is `github`, repo should be in the form of `repo_owner/repo_name/release_tag`.") + repo_owner, repo_name, tag_name = repo.split("/") + if ".zip" not in filename: + filename += ".zip" + url = _get_git_release_url(repo_owner, repo_name, tag_name=tag_name, filename=filename) + filepath = download_path / f"{filename}" + download_url(url=url, filepath=filepath, hash_val=None, progress=progress) + extractall(filepath=filepath, output_dir=download_path, has_base=True) + + +def _process_bundle_dir(bundle_dir: Optional[PathLike] = None): + if bundle_dir is None: + get_dir, has_home = optional_import("torch.hub", name="get_dir") + if has_home: + bundle_dir = Path(get_dir()) / "bundle" + else: + raise ValueError("bundle_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") + return Path(bundle_dir) + + +def download( + name: Optional[str] = None, + bundle_dir: Optional[PathLike] = None, + source: str = "github", + repo: Optional[str] = None, + url: Optional[str] = None, + progress: bool = True, + args_file: Optional[str] = None, +): + """ + download bundle from the specified source or url. The bundle should be a zip file and it + will be extracted after downloading. + This function refers to: + https://pytorch.org/docs/stable/_modules/torch/hub.html + + Typical usage examples: + + .. code-block:: bash + + # Execute this module as a CLI entry, and download bundle: + python -m monai.bundle download --name "bundle_name" --source "github" --repo "repo_owner/repo_name/release_tag" + + # Execute this module as a CLI entry, and download bundle via URL: + python -m monai.bundle download --name "bundle_name" --url + + # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line. + # Other args still can override the default args at runtime. + # The content of the JSON / YAML file is a dictionary. For example: + # {"name": "spleen", "bundle_dir": "download", "source": ""} + # then do the following command for downloading: + python -m monai.bundle download --args_file "args.json" --source "github" + + Args: + name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`. + bundle_dir: target directory to store the downloaded data. + Default is `bundle` subfolder under`torch.hub get_dir()`. + source: place that saved the bundle. + If `source` is `github`, the bundle should be within the releases. + repo: repo name. If `None` and `url` is `None`, it must be provided in `args_file`. + If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`. + For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`. + url: url to download the data. If not `None`, data will be downloaded directly + and `source` will not be checked. + If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`. + progress: whether to display a progress bar. + args_file: a JSON or YAML file to provide default values for all the args in this function. + so that the command line inputs can be simplified. + + """ + _args = _update_args( + args=args_file, name=name, bundle_dir=bundle_dir, source=source, repo=repo, url=url, progress=progress + ) + + _log_input_summary(tag="download", args=_args) + name_, bundle_dir_, source_, repo_, url_, progress_ = _pop_args( + _args, name=None, bundle_dir=None, source="github", repo=None, url=None, progress=True + ) + + bundle_dir_ = _process_bundle_dir(bundle_dir_) + + if url_ is not None: + if name is not None: + filepath = bundle_dir_ / f"{name}.zip" + else: + filepath = bundle_dir_ / f"{_basename(url_)}" + download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_) + extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True) + elif source_ == "github": + if name_ is None or repo_ is None: + raise ValueError( + f"To download from source: Github, `name` and `repo` must be provided, got {name_} and {repo_}." + ) + _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_) + else: + raise NotImplementedError( + f"Currently only download from provided URL in `url` or Github is implemented, got source: {source_}." + ) + + +def load( + name: str, + model_file: Optional[str] = None, + load_ts_module: bool = False, + bundle_dir: Optional[PathLike] = None, + source: str = "github", + repo: Optional[str] = None, + progress: bool = True, + device: Optional[str] = None, + config_files: Sequence[str] = (), + net_name: Optional[str] = None, + **net_kwargs, +): + """ + Load model weights or TorchScript module of a bundle. + + Args: + name: bundle name. + model_file: the relative path of the model weights or TorchScript module within bundle. + If `None`, "models/model.pt" or "models/model.ts" will be used. + load_ts_module: a flag to specify if loading the TorchScript module. + bundle_dir: the directory the weights/TorchScript module will be loaded from. + Default is `bundle` subfolder under`torch.hub get_dir()`. + source: the place that saved the bundle. + If `source` is `github`, the bundle should be within the releases. + repo: the repo name. If the weights file does not exist locally and `url` is `None`, it must be provided. + If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`. + For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`. + progress: whether to display a progress bar when downloading. + device: target device of returned weights or module, if `None`, prefer to "cuda" if existing. + config_files: extra filenames would be loaded. The argument only works when loading a TorchScript module, + see `_extra_files` in `torch.jit.load` for more details. + net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights. + This argument only works when loading weights. + net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`. + + Returns: + 1. If `load_ts_module` is `False` and `net_name` is `None`, return model weights. + 2. If `load_ts_module` is `False` and `net_name` is not `None`, + return an instantiated network that loaded the weights. + 3. If `load_ts_module` is `True`, return a triple that include a TorchScript module, + the corresponding metadata dict, and extra files dict. + please check `monai.data.load_net_with_metadata` for more details. + + """ + bundle_dir_ = _process_bundle_dir(bundle_dir) + + if model_file is None: + model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt") + full_path = os.path.join(bundle_dir_, name, model_file) + if not os.path.exists(full_path): + download(name=name, bundle_dir=bundle_dir_, source=source, repo=repo, progress=progress) + + if device is None: + device = "cuda:0" if is_available() else "cpu" + # loading with `torch.jit.load` + if load_ts_module is True: + return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files) + # loading with `torch.load` + model_dict = torch.load(full_path, map_location=torch.device(device)) + + if net_name is None: + return model_dict + net_kwargs["_target_"] = net_name + configer = ConfigComponent(config=net_kwargs) + model = configer.instantiate() + model.to(device) # type: ignore + model.load_state_dict(model_dict) # type: ignore + return model + + def run( runner_id: Optional[str] = None, meta_file: Optional[Union[str, Sequence[str]]] = None, @@ -249,7 +430,7 @@ def verify_metadata( try: # the rest key-values in the _args are for `validate` API validate(instance=metadata, schema=schema, **_args) - except ValidationError as e: + except ValidationError as e: # pylint: disable=E0712 # as the error message is very long, only extract the key information logger.info(re.compile(r".*Failed validating", re.S).findall(str(e))[0] + f" against schema `{url}`.") return @@ -370,8 +551,10 @@ def ckpt_export( filepath: filepath to export, if filename has no extension it becomes `.ts`. ckpt_file: filepath of the model checkpoint to load. meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged. - config_file: filepath of the config file, if `None`, must be provided in `args_file`. - if it is a list of file paths, the content of them will be merged. + config_file: filepath of the config file to save in TorchScript model and extract network information, + the saved key in the TorchScript model is the config filename without extension, and the saved config + value is always serialized in JSON format no matter the original file format is JSON or YAML. + it can be a single file or a list of files. if `None`, must be provided in `args_file`. key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model weights. if not nested checkpoint, no need to set. args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`, @@ -415,12 +598,22 @@ def ckpt_export( # convert to TorchScript model and save with meta data, config content net = convert_to_torchscript(model=net) + extra_files: Dict = {} + for i in ensure_tuple(config_file_): + # split the filename and directory + filename = os.path.basename(i) + # remove extension + filename, _ = os.path.splitext(filename) + if filename in extra_files: + raise ValueError(f"filename '{filename}' is given multiple times in config file list.") + extra_files[filename] = json.dumps(ConfigParser.load_config_file(i)).encode() + save_net_with_metadata( jit_obj=net, filename_prefix_or_stream=filepath_, include_config_vals=False, append_timestamp=False, meta_values=parser.get().pop("_meta_", None), - more_extra_files={"config": json.dumps(parser.get()).encode()}, + more_extra_files=extra_files, ) logger.info(f"exported to TorchScript file: {filepath_}.") diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index ad5141787c..02032a0ae6 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -180,7 +180,16 @@ def get_data( f"The image dimension should be 3 but has {patch.ndim}. " "`WSIReader` is designed to work only with 2D images with color channel." ) - + # Check if there are four color channels for RGBA + if mode == "RGBA" and patch.shape[0] != 4: + raise ValueError( + f"The image is expected to have four color channels in '{mode}' mode but has {patch.shape[0]}." + ) + # Check if there are three color channels for RGB + elif mode in "RGB" and patch.shape[0] != 3: + raise ValueError( + f"The image is expected to have three color channels in '{mode}' mode but has {patch.shape[0]}. " + ) # Create a list of patches patch_list.append(patch) @@ -408,11 +417,6 @@ def get_patch( patch = AsChannelFirst()(patch) # type: ignore # Check if the color channel is 3 (RGB) or 4 (RGBA) - if mode == "RGBA" and patch.shape[0] != 4: - raise ValueError( - f"The image is expected to have four color channels in '{mode}' mode but has {patch.shape[0]}." - ) - if mode in "RGB": if patch.shape[0] not in [3, 4]: raise ValueError( @@ -537,15 +541,4 @@ def get_patch( # Make it channel first patch = AsChannelFirst()(patch) # type: ignore - # Check if the color channel is 3 (RGB) or 4 (RGBA) - if mode == "RGBA" and patch.shape[0] != 4: - raise ValueError( - f"The image is expected to have four color channels in '{mode}' mode but has {patch.shape[0]}." - ) - - elif mode in "RGB" and patch.shape[0] != 3: - raise ValueError( - f"The image is expected to have three color channels in '{mode}' mode but has {patch.shape[0]}. " - ) - return patch diff --git a/monai/networks/blocks/dints_block.py b/monai/networks/blocks/dints_block.py index f76e125fe0..b7365f50e3 100644 --- a/monai/networks/blocks/dints_block.py +++ b/monai/networks/blocks/dints_block.py @@ -31,7 +31,7 @@ def __init__( out_channel: int, spatial_dims: int = 3, act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = "INSTANCE", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), ): """ Args: @@ -82,7 +82,7 @@ def __init__( out_channel: int, spatial_dims: int = 3, act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = "INSTANCE", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), ): """ Args: @@ -150,7 +150,7 @@ def __init__( padding: int, mode: int = 0, act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = "INSTANCE", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), ): """ Args: @@ -235,7 +235,7 @@ def __init__( padding: int = 1, spatial_dims: int = 3, act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = "INSTANCE", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), ): """ Args: diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py index 978695c5d0..b7f3921a47 100644 --- a/monai/networks/nets/dints.py +++ b/monai/networks/nets/dints.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import warnings from typing import List, Optional, Tuple, Union @@ -96,22 +97,47 @@ class _ActiConvNormBlockWithRAMCost(ActiConvNormBlock): ram_cost = total_ram/output_size = 2 * in_channel/out_channel + 1 """ - def __init__(self, in_channel: int, out_channel: int, kernel_size: int, padding: int, spatial_dims: int = 3): - super().__init__(in_channel, out_channel, kernel_size, padding, spatial_dims) + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: int, + padding: int, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + ): + super().__init__(in_channel, out_channel, kernel_size, padding, spatial_dims, act_name, norm_name) self.ram_cost = 1 + in_channel / out_channel * 2 class _P3DActiConvNormBlockWithRAMCost(P3DActiConvNormBlock): - def __init__(self, in_channel: int, out_channel: int, kernel_size: int, padding: int, p3dmode: int = 0): - super().__init__(in_channel, out_channel, kernel_size, padding, p3dmode) + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: int, + padding: int, + p3dmode: int = 0, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + ): + super().__init__(in_channel, out_channel, kernel_size, padding, p3dmode, act_name, norm_name) # 1 in_channel (activation) + 1 in_channel (convolution) + # 1 out_channel (convolution) + 1 out_channel (normalization) self.ram_cost = 2 + 2 * in_channel / out_channel class _FactorizedIncreaseBlockWithRAMCost(FactorizedIncreaseBlock): - def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3): - super().__init__(in_channel, out_channel, spatial_dims) + def __init__( + self, + in_channel: int, + out_channel: int, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + ): + super().__init__(in_channel, out_channel, spatial_dims, act_name, norm_name) # s0 is upsampled 2x from s1, representing feature sizes at two resolutions. # 2 * in_channel * s0 (upsample + activation) + 2 * out_channel * s0 (conv + normalization) # s0 = output_size/out_channel @@ -119,8 +145,15 @@ def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3): class _FactorizedReduceBlockWithRAMCost(FactorizedReduceBlock): - def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3): - super().__init__(in_channel, out_channel, spatial_dims) + def __init__( + self, + in_channel: int, + out_channel: int, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + ): + super().__init__(in_channel, out_channel, spatial_dims, act_name, norm_name) # s0 is upsampled 2x from s1, representing feature sizes at two resolutions. # in_channel * s0 (activation) + 3 * out_channel * s1 (convolution, concatenation, normalization) # s0 = s1 * 2^(spatial_dims) = output_size / out_channel * 2^(spatial_dims) @@ -182,14 +215,6 @@ class Cell(CellInterface): # \ # - Downsample - # Define connection operation set, parameterized by the number of channels - ConnOPS = { - "up": _FactorizedIncreaseBlockWithRAMCost, - "down": _FactorizedReduceBlockWithRAMCost, - "identity": _IdentityWithRAMCost, - "align_channels": _ActiConvNormBlockWithRAMCost, - } - # Define 2D operation set, parameterized by the number of channels OPS2D = { "skip_connect": lambda _c: _IdentityWithRAMCost(), @@ -205,18 +230,69 @@ class Cell(CellInterface): "conv_1x3x3": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=2), } - def __init__(self, c_prev: int, c: int, rate: int, arch_code_c=None, spatial_dims: int = 3): + # Define connection operation set, parameterized by the number of channels + ConnOPS = { + "up": _FactorizedIncreaseBlockWithRAMCost, + "down": _FactorizedReduceBlockWithRAMCost, + "identity": _IdentityWithRAMCost, + "align_channels": _ActiConvNormBlockWithRAMCost, + } + + def __init__( + self, + c_prev: int, + c: int, + rate: int, + arch_code_c=None, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + ): super().__init__() self._spatial_dims = spatial_dims + self._act_name = act_name + self._norm_name = norm_name + if rate == -1: # downsample - self.preprocess = self.ConnOPS["down"](c_prev, c, spatial_dims=self._spatial_dims) + self.preprocess = self.ConnOPS["down"]( + c_prev, c, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name + ) elif rate == 1: # upsample - self.preprocess = self.ConnOPS["up"](c_prev, c, spatial_dims=self._spatial_dims) + self.preprocess = self.ConnOPS["up"]( + c_prev, c, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name + ) else: if c_prev == c: self.preprocess = self.ConnOPS["identity"]() else: - self.preprocess = self.ConnOPS["align_channels"](c_prev, c, 1, 0, spatial_dims=self._spatial_dims) + self.preprocess = self.ConnOPS["align_channels"]( + c_prev, c, 1, 0, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name + ) + + # Define 2D operation set, parameterized by the number of channels + self.OPS2D = { + "skip_connect": lambda _c: _IdentityWithRAMCost(), + "conv_3x3": lambda c: _ActiConvNormBlockWithRAMCost( + c, c, 3, padding=1, spatial_dims=2, act_name=self._act_name, norm_name=self._norm_name + ), + } + + # Define 3D operation set, parameterized by the number of channels + self.OPS3D = { + "skip_connect": lambda _c: _IdentityWithRAMCost(), + "conv_3x3x3": lambda c: _ActiConvNormBlockWithRAMCost( + c, c, 3, padding=1, spatial_dims=3, act_name=self._act_name, norm_name=self._norm_name + ), + "conv_3x3x1": lambda c: _P3DActiConvNormBlockWithRAMCost( + c, c, 3, padding=1, p3dmode=0, act_name=self._act_name, norm_name=self._norm_name + ), + "conv_3x1x3": lambda c: _P3DActiConvNormBlockWithRAMCost( + c, c, 3, padding=1, p3dmode=1, act_name=self._act_name, norm_name=self._norm_name + ), + "conv_1x3x3": lambda c: _P3DActiConvNormBlockWithRAMCost( + c, c, 3, padding=1, p3dmode=2, act_name=self._act_name, norm_name=self._norm_name + ), + } self.OPS = {} if self._spatial_dims == 2: @@ -283,7 +359,7 @@ def __init__( in_channels: int, num_classes: int, act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = "INSTANCE", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), spatial_dims: int = 3, use_downsample: bool = True, node_a=None, @@ -398,7 +474,9 @@ def __init__( bias=False, dilation=1, ), - get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx - 1]), + get_norm_layer( + name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[max(res_idx - 1, 0)] + ), nn.Upsample(scale_factor=2 ** (res_idx != 0), mode=mode, align_corners=True), ) @@ -484,6 +562,8 @@ def __init__( num_blocks: int = 6, num_depths: int = 3, spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), use_downsample: bool = True, device: str = "cpu", ): @@ -494,6 +574,8 @@ def __init__( self.num_blocks = num_blocks self.num_depths = num_depths self._spatial_dims = spatial_dims + self._act_name = act_name + self._norm_name = norm_name self.use_downsample = use_downsample self.device = device self.num_cell_ops = 0 @@ -535,6 +617,8 @@ def __init__( self.arch_code2ops[res_idx], self.arch_code_c[blk_idx, res_idx], self._spatial_dims, + self._act_name, + self._norm_name, ) def forward(self, x): @@ -555,6 +639,8 @@ def __init__( num_blocks: int = 6, num_depths: int = 3, spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), use_downsample: bool = True, device: str = "cpu", ): @@ -571,6 +657,8 @@ def __init__( num_blocks=num_blocks, num_depths=num_depths, spatial_dims=spatial_dims, + act_name=act_name, + norm_name=norm_name, use_downsample=use_downsample, device=device, ) @@ -591,7 +679,7 @@ def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: x=inputs[self.arch_code2in[res_idx]], weight=torch.ones_like(self.arch_code_c[blk_idx, res_idx]) ) outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out - inputs = outputs + inputs = outputs return inputs @@ -650,6 +738,8 @@ def __init__( num_blocks: int = 6, num_depths: int = 3, spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), use_downsample: bool = True, device: str = "cpu", ): @@ -663,6 +753,8 @@ def __init__( num_blocks=num_blocks, num_depths=num_depths, spatial_dims=spatial_dims, + act_name=act_name, + norm_name=norm_name, use_downsample=use_downsample, device=device, ) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 581e368ba0..c2385499b3 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -311,6 +311,7 @@ AffineGrid, Flip, GridDistortion, + GridSplit, Orientation, Rand2DElastic, Rand3DElastic, @@ -342,6 +343,9 @@ GridDistortiond, GridDistortionD, GridDistortionDict, + GridSplitd, + GridSplitD, + GridSplitDict, Orientationd, OrientationD, OrientationDict, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 37f1c3edc3..6b67762b95 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -18,6 +18,7 @@ import numpy as np import torch +from numpy.lib.stride_tricks import as_strided from monai.config import USE_COMPILED, DtypeLike from monai.config.type_definitions import NdarrayOrTensor @@ -65,6 +66,7 @@ "Orientation", "Flip", "GridDistortion", + "GridSplit", "Resize", "Rotate", "Zoom", @@ -2462,3 +2464,91 @@ def __call__( if not self._do_transform: return img return self.grid_distortion(img, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode) + + +class GridSplit(Transform): + """ + Split the image into patches based on the provided grid in 2D. + + Args: + grid: a tuple define the shape of the grid upon which the image is split. Defaults to (2, 2) + size: a tuple or an integer that defines the output patch sizes. + If it's an integer, the value will be repeated for each dimension. + The default is None, where the patch size will be inferred from the grid shape. + + Example: + Given an image (torch.Tensor or numpy.ndarray) with size of (3, 10, 10) and a grid of (2, 2), + it will return a Tensor or array with the size of (4, 3, 5, 5). + Here, if the `size` is provided, the returned shape will be (4, 3, size, size) + + Note: This transform currently support only image with two spatial dimensions. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, grid: Tuple[int, int] = (2, 2), size: Optional[Union[int, Tuple[int, int]]] = None): + # Grid size + self.grid = grid + + # Patch size + self.size = None if size is None else ensure_tuple_rep(size, len(self.grid)) + + def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: + if self.grid == (1, 1) and self.size is None: + if isinstance(image, torch.Tensor): + return torch.stack([image]) + elif isinstance(image, np.ndarray): + return np.stack([image]) # type: ignore + else: + raise ValueError(f"Input type [{type(image)}] is not supported.") + + size, steps = self._get_params(image.shape[1:]) + patches: NdarrayOrTensor + if isinstance(image, torch.Tensor): + patches = ( + image.unfold(1, size[0], steps[0]) + .unfold(2, size[1], steps[1]) + .flatten(1, 2) + .transpose(0, 1) + .contiguous() + ) + elif isinstance(image, np.ndarray): + x_step, y_step = steps + c_stride, x_stride, y_stride = image.strides + n_channels = image.shape[0] + patches = as_strided( + image, + shape=(*self.grid, n_channels, size[0], size[1]), + strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), + writeable=False, + ) + # flatten the first two dimensions + patches = patches.reshape(np.prod(patches.shape[:2]), *patches.shape[2:]) + # make it a contiguous array + patches = np.ascontiguousarray(patches) + else: + raise ValueError(f"Input type [{type(image)}] is not supported.") + + return patches + + def _get_params(self, image_size: Union[Sequence[int], np.ndarray]): + """ + Calculate the size and step required for splitting the image + Args: + The size of the input image + """ + if self.size is not None: + # Set the split size to the given default size + if any(self.size[i] > image_size[i] for i in range(len(self.grid))): + raise ValueError("The image size ({image_size})is smaller than the requested split size ({self.size})") + split_size = self.size + else: + # infer each sub-image size from the image size and the grid + split_size = tuple(image_size[i] // self.grid[i] for i in range(len(self.grid))) + + steps = tuple( + (image_size[i] - split_size[i]) // (self.grid[i] - 1) if self.grid[i] > 1 else image_size[i] + for i in range(len(self.grid)) + ) + + return split_size, steps diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index d42a11fd2f..47fe05700e 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -34,6 +34,7 @@ AffineGrid, Flip, GridDistortion, + GridSplit, Orientation, Rand2DElastic, Rand3DElastic, @@ -129,6 +130,9 @@ "ZoomDict", "RandZoomD", "RandZoomDict", + "GridSplitd", + "GridSplitD", + "GridSplitDict", ] GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] @@ -2149,6 +2153,40 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d +class GridSplitd(MapTransform): + """ + Split the image into patches based on the provided grid in 2D. + + Args: + keys: keys of the corresponding items to be transformed. + grid: a tuple define the shape of the grid upon which the image is split. Defaults to (2, 2) + size: a tuple or an integer that defines the output patch sizes. + If it's an integer, the value will be repeated for each dimension. + The default is None, where the patch size will be inferred from the grid shape. + allow_missing_keys: don't raise exception if key is missing. + + Note: This transform currently support only image with two spatial dimensions. + """ + + backend = GridSplit.backend + + def __init__( + self, + keys: KeysCollection, + grid: Tuple[int, int] = (2, 2), + size: Optional[Union[int, Tuple[int, int]]] = None, + allow_missing_keys: bool = False, + ): + super().__init__(keys, allow_missing_keys) + self.splitter = GridSplit(grid=grid, size=size) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.splitter(d[key]) + return d + + SpatialResampleD = SpatialResampleDict = SpatialResampled ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd SpacingD = SpacingDict = Spacingd @@ -2169,3 +2207,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N RandRotateD = RandRotateDict = RandRotated ZoomD = ZoomDict = Zoomd RandZoomD = RandZoomDict = RandZoomd +GridSplitD = GridSplitDict = GridSplitd diff --git a/tests/test_add_extreme_points_channel.py b/tests/test_add_extreme_points_channel.py index d2c8a627b6..116f96126f 100644 --- a/tests/test_add_extreme_points_channel.py +++ b/tests/test_add_extreme_points_channel.py @@ -71,7 +71,7 @@ class TestAddExtremePointsChannel(unittest.TestCase): def test_correct_results(self, input_data, expected): add_extreme_points_channel = AddExtremePointsChannel() result = add_extreme_points_channel(**input_data) - assert_allclose(result[IMG_CHANNEL], expected, rtol=1e-4) + assert_allclose(result[IMG_CHANNEL], expected, rtol=1e-4, atol=1e-5) if __name__ == "__main__": diff --git a/tests/test_add_extreme_points_channeld.py b/tests/test_add_extreme_points_channeld.py index 39d221596f..f9837e9ef4 100644 --- a/tests/test_add_extreme_points_channeld.py +++ b/tests/test_add_extreme_points_channeld.py @@ -68,7 +68,7 @@ def test_correct_results(self, input_data, expected): keys="img", label_key="label", sigma=1.0, rescale_min=0.0, rescale_max=1.0 ) result = add_extreme_points_channel(input_data) - assert_allclose(result["img"][IMG_CHANNEL], expected, rtol=1e-4) + assert_allclose(result["img"][IMG_CHANNEL], expected, rtol=1e-4, atol=1e-5) if __name__ == "__main__": diff --git a/tests/test_bundle_ckpt_export.py b/tests/test_bundle_ckpt_export.py index 0f7d0f7d35..36aa7319f0 100644 --- a/tests/test_bundle_ckpt_export.py +++ b/tests/test_bundle_ckpt_export.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os import subprocess import tempfile @@ -17,6 +18,7 @@ from parameterized import parameterized from monai.bundle import ConfigParser +from monai.data import load_net_with_metadata from monai.networks import save_state from tests.utils import skip_if_windows @@ -33,7 +35,8 @@ def test_export(self, key_in_ckpt): config_file = os.path.join(os.path.dirname(__file__), "testing_data", "inference.json") with tempfile.TemporaryDirectory() as tempdir: def_args = {"meta_file": "will be replaced by `meta_file` arg"} - def_args_file = os.path.join(tempdir, "def_args.json") + def_args_file = os.path.join(tempdir, "def_args.yaml") + ckpt_file = os.path.join(tempdir, "model.pt") ts_file = os.path.join(tempdir, "model.ts") @@ -44,11 +47,16 @@ def test_export(self, key_in_ckpt): save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file) cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", ts_file] - cmd += ["--meta_file", meta_file, "--config_file", config_file, "--ckpt_file", ckpt_file] - cmd += ["--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file] + cmd += ["--meta_file", meta_file, "--config_file", f"['{config_file}','{def_args_file}']", "--ckpt_file"] + cmd += [ckpt_file, "--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file] subprocess.check_call(cmd) self.assertTrue(os.path.exists(ts_file)) + _, metadata, extra_files = load_net_with_metadata(ts_file, more_extra_files=["inference", "def_args"]) + self.assertTrue("schema" in metadata) + self.assertTrue("meta_file" in json.loads(extra_files["def_args"])) + self.assertTrue("network_def" in json.loads(extra_files["inference"])) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py new file mode 100644 index 0000000000..7e609a7b31 --- /dev/null +++ b/tests/test_bundle_download.py @@ -0,0 +1,173 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import subprocess +import tempfile +import unittest + +import torch +from parameterized import parameterized + +import monai.networks.nets as nets +from monai.apps import check_hash +from monai.bundle import ConfigParser, load +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_quick, skip_if_windows + +TEST_CASE_1 = [ + ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"], + "test_bundle", + "Project-MONAI/MONAI-extra-test-data/0.8.1", + "a131d39a0af717af32d19e565b434928", +] + +TEST_CASE_2 = [ + ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"], + "test_bundle", + "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/test_bundle.zip", + "a131d39a0af717af32d19e565b434928", +] + +TEST_CASE_3 = [ + ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"], + "test_bundle", + "Project-MONAI/MONAI-extra-test-data/0.8.1", + "cuda" if torch.cuda.is_available() else "cpu", + "model.pt", +] + +TEST_CASE_4 = [ + ["test_output.pt", "test_input.pt"], + "test_bundle", + "Project-MONAI/MONAI-extra-test-data/0.8.1", + "cuda" if torch.cuda.is_available() else "cpu", + "model.ts", +] + + +@skip_if_windows +class TestDownload(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + @skip_if_quick + def test_download_bundle(self, bundle_files, bundle_name, repo, hash_val): + with skip_if_downloading_fails(): + # download a whole bundle from github releases + with tempfile.TemporaryDirectory() as tempdir: + cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--name", bundle_name, "--source", "github"] + cmd += ["--bundle_dir", tempdir, "--repo", repo, "--progress", "False"] + subprocess.check_call(cmd) + for file in bundle_files: + file_path = os.path.join(tempdir, bundle_name, file) + self.assertTrue(os.path.exists(file_path)) + if file == "network.json": + self.assertTrue(check_hash(filepath=file_path, val=hash_val)) + + @parameterized.expand([TEST_CASE_2]) + @skip_if_quick + def test_url_download_bundle(self, bundle_files, bundle_name, url, hash_val): + with skip_if_downloading_fails(): + # download a single file from url, also use `args_file` + with tempfile.TemporaryDirectory() as tempdir: + def_args = {"name": bundle_name, "bundle_dir": tempdir, "url": ""} + def_args_file = os.path.join(tempdir, "def_args.json") + parser = ConfigParser() + parser.export_config_file(config=def_args, filepath=def_args_file) + cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--args_file", def_args_file] + cmd += ["--url", url] + subprocess.check_call(cmd) + for file in bundle_files: + file_path = os.path.join(tempdir, bundle_name, file) + self.assertTrue(os.path.exists(file_path)) + if file == "network.json": + self.assertTrue(check_hash(filepath=file_path, val=hash_val)) + + +class TestLoad(unittest.TestCase): + @parameterized.expand([TEST_CASE_3]) + @skip_if_quick + def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file): + with skip_if_downloading_fails(): + # download bundle, and load weights from the downloaded path + with tempfile.TemporaryDirectory() as tempdir: + # load weights + weights = load( + name=bundle_name, + model_file=model_file, + bundle_dir=tempdir, + repo=repo, + progress=False, + device=device, + ) + + # prepare network + with open(os.path.join(tempdir, bundle_name, bundle_files[2])) as f: + net_args = json.load(f)["network_def"] + model_name = net_args["_target_"] + del net_args["_target_"] + model = nets.__dict__[model_name](**net_args) + model.to(device) + model.load_state_dict(weights) + model.eval() + + # prepare data and test + input_tensor = torch.load(os.path.join(tempdir, bundle_name, bundle_files[4]), map_location=device) + output = model.forward(input_tensor) + expected_output = torch.load(os.path.join(tempdir, bundle_name, bundle_files[3]), map_location=device) + torch.testing.assert_allclose(output, expected_output) + + # load instantiated model directly and test, since the bundle has been downloaded, + # there is no need to input `repo` + model_2 = load( + name=bundle_name, + model_file=model_file, + bundle_dir=tempdir, + progress=False, + device=device, + net_name=model_name, + **net_args, + ) + model_2.eval() + output_2 = model_2.forward(input_tensor) + torch.testing.assert_allclose(output_2, expected_output) + + @parameterized.expand([TEST_CASE_4]) + @skip_if_quick + @SkipIfBeforePyTorchVersion((1, 7, 1)) + def test_load_ts_module(self, bundle_files, bundle_name, repo, device, model_file): + with skip_if_downloading_fails(): + # load ts module + with tempfile.TemporaryDirectory() as tempdir: + # load ts module + model_ts, metadata, extra_file_dict = load( + name=bundle_name, + model_file=model_file, + load_ts_module=True, + bundle_dir=tempdir, + repo=repo, + progress=False, + device=device, + config_files=("network.json",), + ) + + # prepare and test ts + input_tensor = torch.load(os.path.join(tempdir, bundle_name, bundle_files[1]), map_location=device) + output = model_ts.forward(input_tensor) + expected_output = torch.load(os.path.join(tempdir, bundle_name, bundle_files[0]), map_location=device) + torch.testing.assert_allclose(output, expected_output) + # test metadata + self.assertTrue(metadata["pytorch_version"] == "1.7.1") + # test extra_file_dict + self.assertTrue("network.json" in extra_file_dict.keys()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dints_cell.py b/tests/test_dints_cell.py index d480235b70..a5da39bae9 100644 --- a/tests/test_dints_cell.py +++ b/tests/test_dints_cell.py @@ -32,21 +32,28 @@ (2, 4, 64, 32, 16), ], [ - {"c_prev": 8, "c": 8, "rate": 0, "arch_code_c": None}, + {"c_prev": 8, "c": 8, "rate": 0, "arch_code_c": None, "act_name": "SELU", "norm_name": "BATCH"}, torch.tensor([1, 1, 1, 1, 1]), torch.tensor([0, 0, 0, 1, 0]), (2, 8, 32, 16, 8), (2, 8, 32, 16, 8), ], [ - {"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": None}, + { + "c_prev": 8, + "c": 8, + "rate": -1, + "arch_code_c": None, + "act_name": "PRELU", + "norm_name": ("BATCH", {"affine": False}), + }, torch.tensor([1, 1, 1, 1, 1]), torch.tensor([1, 1, 1, 1, 1]), (2, 8, 32, 16, 8), (2, 8, 16, 8, 4), ], [ - {"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1]}, + {"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1], "act_name": "RELU", "norm_name": "INSTANCE"}, torch.tensor([1, 0, 0, 0, 1]), torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2]), (2, 8, 32, 16, 8), @@ -56,12 +63,35 @@ TEST_CASES_2D = [ [ - {"c_prev": 8, "c": 7, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1], "spatial_dims": 2}, + { + "c_prev": 8, + "c": 7, + "rate": -1, + "arch_code_c": [1, 0, 0, 0, 1], + "spatial_dims": 2, + "act_name": "PRELU", + "norm_name": ("BATCH", {"affine": False}), + }, torch.tensor([1, 0]), torch.tensor([0.2, 0.2]), (2, 8, 16, 8), (2, 7, 8, 4), - ] + ], + [ + { + "c_prev": 8, + "c": 8, + "rate": -1, + "arch_code_c": None, + "spatial_dims": 2, + "act_name": "SELU", + "norm_name": "INSTANCE", + }, + torch.tensor([1, 0]), + torch.tensor([0.2, 0.2]), + (2, 8, 16, 8), + (2, 8, 8, 4), + ], ] diff --git a/tests/test_dints_network.py b/tests/test_dints_network.py index 8be5eb7ccd..08e75fab98 100644 --- a/tests/test_dints_network.py +++ b/tests/test_dints_network.py @@ -33,7 +33,7 @@ "in_channels": 1, "num_classes": 3, "act_name": "RELU", - "norm_name": "INSTANCE", + "norm_name": ("INSTANCE", {"affine": True}), "use_downsample": False, "spatial_dims": 3, }, @@ -101,7 +101,7 @@ "in_channels": 1, "num_classes": 4, "act_name": "RELU", - "norm_name": "INSTANCE", + "norm_name": ("INSTANCE", {"affine": True}), "use_downsample": False, "spatial_dims": 2, }, diff --git a/tests/test_grid_split.py b/tests/test_grid_split.py new file mode 100644 index 0000000000..6f0525029d --- /dev/null +++ b/tests/test_grid_split.py @@ -0,0 +1,84 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms import GridSplit +from tests.utils import TEST_NDARRAYS, assert_allclose + +A11 = torch.randn(3, 2, 2) +A12 = torch.randn(3, 2, 2) +A21 = torch.randn(3, 2, 2) +A22 = torch.randn(3, 2, 2) + +A1 = torch.cat([A11, A12], 2) +A2 = torch.cat([A21, A22], 2) +A = torch.cat([A1, A2], 1) + +TEST_CASE_0 = [{"grid": (2, 2)}, A, torch.stack([A11, A12, A21, A22])] +TEST_CASE_1 = [{"grid": (2, 1)}, A, torch.stack([A1, A2])] +TEST_CASE_2 = [{"grid": (1, 2)}, A1, torch.stack([A11, A12])] +TEST_CASE_3 = [{"grid": (1, 2)}, A2, torch.stack([A21, A22])] +TEST_CASE_4 = [{"grid": (1, 1), "size": (2, 2)}, A, torch.stack([A11])] +TEST_CASE_5 = [{"grid": (1, 1), "size": 4}, A, torch.stack([A])] +TEST_CASE_6 = [{"grid": (2, 2), "size": 2}, A, torch.stack([A11, A12, A21, A22])] +TEST_CASE_7 = [{"grid": (1, 1)}, A, torch.stack([A])] +TEST_CASE_8 = [ + {"grid": (2, 2), "size": 2}, + torch.arange(12).reshape(1, 3, 4).to(torch.float32), + torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32), +] + +TEST_SINGLE = [] +for p in TEST_NDARRAYS: + TEST_SINGLE.append([p, *TEST_CASE_0]) + TEST_SINGLE.append([p, *TEST_CASE_1]) + TEST_SINGLE.append([p, *TEST_CASE_2]) + TEST_SINGLE.append([p, *TEST_CASE_3]) + TEST_SINGLE.append([p, *TEST_CASE_4]) + TEST_SINGLE.append([p, *TEST_CASE_5]) + TEST_SINGLE.append([p, *TEST_CASE_6]) + TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_SINGLE.append([p, *TEST_CASE_8]) + +TEST_CASE_MC_0 = [{"grid": (2, 2)}, [A, A], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])]] +TEST_CASE_MC_1 = [{"grid": (2, 1)}, [A] * 5, [torch.stack([A1, A2])] * 5] +TEST_CASE_MC_2 = [{"grid": (1, 2)}, [A1, A2], [torch.stack([A11, A12]), torch.stack([A21, A22])]] + +TEST_MULTIPLE = [] +for p in TEST_NDARRAYS: + TEST_MULTIPLE.append([p, *TEST_CASE_MC_0]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_1]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_2]) + + +class TestGridSplit(unittest.TestCase): + @parameterized.expand(TEST_SINGLE) + def test_split_patch_single_call(self, in_type, input_parameters, image, expected): + input_image = in_type(image) + splitter = GridSplit(**input_parameters) + output = splitter(input_image) + assert_allclose(output, expected, type_test=False) + + @parameterized.expand(TEST_MULTIPLE) + def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): + splitter = GridSplit(**input_parameters) + for image, expected in zip(img_list, expected_list): + input_image = in_type(image) + output = splitter(input_image) + assert_allclose(output, expected, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_grid_splitd.py b/tests/test_grid_splitd.py new file mode 100644 index 0000000000..f325a16946 --- /dev/null +++ b/tests/test_grid_splitd.py @@ -0,0 +1,100 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms import GridSplitd +from tests.utils import TEST_NDARRAYS, assert_allclose + +A11 = torch.randn(3, 2, 2) +A12 = torch.randn(3, 2, 2) +A21 = torch.randn(3, 2, 2) +A22 = torch.randn(3, 2, 2) + +A1 = torch.cat([A11, A12], 2) +A2 = torch.cat([A21, A22], 2) +A = torch.cat([A1, A2], 1) + +TEST_CASE_0 = [{"keys": "image", "grid": (2, 2)}, {"image": A}, torch.stack([A11, A12, A21, A22])] +TEST_CASE_1 = [{"keys": "image", "grid": (2, 1)}, {"image": A}, torch.stack([A1, A2])] +TEST_CASE_2 = [{"keys": "image", "grid": (1, 2)}, {"image": A1}, torch.stack([A11, A12])] +TEST_CASE_3 = [{"keys": "image", "grid": (1, 2)}, {"image": A2}, torch.stack([A21, A22])] +TEST_CASE_4 = [{"keys": "image", "grid": (1, 1), "size": (2, 2)}, {"image": A}, torch.stack([A11])] +TEST_CASE_5 = [{"keys": "image", "grid": (1, 1), "size": 4}, {"image": A}, torch.stack([A])] +TEST_CASE_6 = [{"keys": "image", "grid": (2, 2), "size": 2}, {"image": A}, torch.stack([A11, A12, A21, A22])] +TEST_CASE_7 = [{"keys": "image", "grid": (1, 1)}, {"image": A}, torch.stack([A])] +TEST_CASE_8 = [ + {"keys": "image", "grid": (2, 2), "size": 2}, + {"image": torch.arange(12).reshape(1, 3, 4).to(torch.float32)}, + torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32), +] + +TEST_SINGLE = [] +for p in TEST_NDARRAYS: + TEST_SINGLE.append([p, *TEST_CASE_0]) + TEST_SINGLE.append([p, *TEST_CASE_1]) + TEST_SINGLE.append([p, *TEST_CASE_2]) + TEST_SINGLE.append([p, *TEST_CASE_3]) + TEST_SINGLE.append([p, *TEST_CASE_4]) + TEST_SINGLE.append([p, *TEST_CASE_5]) + TEST_SINGLE.append([p, *TEST_CASE_6]) + TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_SINGLE.append([p, *TEST_CASE_8]) + +TEST_CASE_MC_0 = [ + {"keys": "image", "grid": (2, 2)}, + [{"image": A}, {"image": A}], + [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])], +] +TEST_CASE_MC_1 = [ + {"keys": "image", "grid": (2, 1)}, + [{"image": A}, {"image": A}, {"image": A}], + [torch.stack([A1, A2])] * 3, +] +TEST_CASE_MC_2 = [ + {"keys": "image", "grid": (1, 2)}, + [{"image": A1}, {"image": A2}], + [torch.stack([A11, A12]), torch.stack([A21, A22])], +] + +TEST_MULTIPLE = [] +for p in TEST_NDARRAYS: + TEST_MULTIPLE.append([p, *TEST_CASE_MC_0]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_1]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_2]) + + +class TestGridSplitd(unittest.TestCase): + @parameterized.expand(TEST_SINGLE) + def test_split_patch_single_call(self, in_type, input_parameters, img_dict, expected): + input_dict = {} + for k, v in img_dict.items(): + input_dict[k] = in_type(v) + splitter = GridSplitd(**input_parameters) + output = splitter(input_dict)[input_parameters["keys"]] + assert_allclose(output, expected, type_test=False) + + @parameterized.expand(TEST_MULTIPLE) + def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): + splitter = GridSplitd(**input_parameters) + for img_dict, expected in zip(img_list, expected_list): + input_dict = {} + for k, v in img_dict.items(): + input_dict[k] = in_type(v) + output = splitter(input_dict)[input_parameters["keys"]] + assert_allclose(output, expected, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 688f664089..f748eb8732 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -54,7 +54,7 @@ from monai.utils import set_determinism from monai.utils.enums import PostFix from tests.testing_data.integration_answers import test_integration_value -from tests.utils import DistTestCase, TimedCall, skip_if_quick +from tests.utils import DistTestCase, TimedCall, pytorch_after, skip_if_quick TASK = "integration_workflows" @@ -149,7 +149,7 @@ def _forward_completed(self, engine): val_handlers=val_handlers, amp=bool(amp), to_kwargs={"memory_format": torch.preserve_format}, - amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32}, + amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32} if pytorch_after(1, 10, 0) else {}, ) train_postprocessing = Compose( @@ -205,7 +205,7 @@ def _model_completed(self, engine): amp=bool(amp), optim_set_to_none=True, to_kwargs={"memory_format": torch.preserve_format}, - amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32}, + amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32} if pytorch_after(1, 10, 0) else {}, ) trainer.run() diff --git a/tests/test_lesion_froc.py b/tests/test_lesion_froc.py index 6b4989a9d5..b135b3eaeb 100644 --- a/tests/test_lesion_froc.py +++ b/tests/test_lesion_froc.py @@ -31,7 +31,7 @@ def save_as_tif(filename, array): if not filename.endswith(".tif"): filename += ".tif" file_path = os.path.join("tests", "testing_data", filename) - imwrite(file_path, array, compress="jpeg", tile=(16, 16)) + imwrite(file_path, array, compression="jpeg", tile=(16, 16)) def around(val, interval=3): diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py index 6ee02143b8..3655100dab 100644 --- a/tests/test_wsireader.py +++ b/tests/test_wsireader.py @@ -84,7 +84,7 @@ TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW -TEST_CASE_ERROR_GRAY = [np.ones((16, 16), dtype=np.uint8)] # no color channel +TEST_CASE_ERROR_GRAY = [np.ones((16, 16, 2), dtype=np.uint8)] # wrong color channel TEST_CASE_ERROR_3D = [np.ones((16, 16, 16, 3), dtype=np.uint8)] # 3D + color @@ -115,7 +115,7 @@ def save_gray_tiff(array: np.ndarray, filename: str): filename: the filename to be used for the tiff file. """ img_gray = array - imwrite(filename, img_gray, shape=img_gray.shape, photometric="rgb") + imwrite(filename, img_gray, shape=img_gray.shape, photometric="minisblack") return filename diff --git a/tests/test_wsireader_new.py b/tests/test_wsireader_new.py index 456f5a9453..63d61dfeb3 100644 --- a/tests/test_wsireader_new.py +++ b/tests/test_wsireader_new.py @@ -72,7 +72,7 @@ TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW -TEST_CASE_ERROR_GRAY = [np.ones((16, 16), dtype=np.uint8)] # no color channel +TEST_CASE_ERROR_GRAY = [np.ones((16, 16, 2), dtype=np.uint8)] # wrong color channel TEST_CASE_ERROR_3D = [np.ones((16, 16, 16, 3), dtype=np.uint8)] # 3D + color @@ -103,7 +103,7 @@ def save_gray_tiff(array: np.ndarray, filename: str): filename: the filename to be used for the tiff file. """ img_gray = array - imwrite(filename, img_gray, shape=img_gray.shape, photometric="rgb") + imwrite(filename, img_gray, shape=img_gray.shape, photometric="minisblack") return filename