From 7187c69223af734cd64c98d2f55490ca9bd469ac Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Wed, 15 Sep 2021 14:15:20 -0700 Subject: [PATCH 01/11] add multimodal transformers Signed-off-by: ahatamizadeh --- docs/requirements.txt | 1 + docs/source/installation.md | 4 +- monai/config/deviceconfig.py | 1 + monai/networks/nets/vltransformer.py | 355 +++++++++++++++++++++++++++ monai/transforms/utility/array.py | 77 ++++-- monai/utils/type_conversion.py | 128 +++++++--- requirements-dev.txt | 1 + setup.cfg | 3 + tests/min_tests.py | 1 + tests/test_to_tensor.py | 16 +- tests/test_vltransformer.py | 80 ++++++ 11 files changed, 601 insertions(+), 66 deletions(-) create mode 100644 monai/networks/nets/vltransformer.py create mode 100644 tests/test_vltransformer.py diff --git a/docs/requirements.txt b/docs/requirements.txt index 00dd4d2c1e..3530d63c49 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -20,3 +20,4 @@ sphinxcontrib-serializinghtml sphinx-autodoc-typehints==1.11.1 pandas einops +transformers==4.10.2 diff --git a/docs/source/installation.md b/docs/source/installation.md index 08ab109142..902f596dfc 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -174,9 +174,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers] ``` which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas` and `einops`, respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas` , `einops` and `transformers`, respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index 273431fc72..ff45b29531 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -73,6 +73,7 @@ def get_optional_config_values(): output["psutil"] = psutil_version output["pandas"] = get_package_version("pandas") output["einops"] = get_package_version("einops") + output["transformers"] = get_package_version("transformers") return output diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py new file mode 100644 index 0000000000..af095a181c --- /dev/null +++ b/monai/networks/nets/vltransformer.py @@ -0,0 +1,355 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import shutil +import tarfile +import tempfile +from typing import Sequence, Union + +import torch +from torch import nn + +from monai.utils import optional_import + +transformers = optional_import("transformers") +load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert") +cached_path = optional_import("transformers.file_utils", name="cached_path")[0] +BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0] +BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0] + + +class BertPreTrainedModel(nn.Module): + """Module to load BERT pre-trained weights. + Based on: + LXMERT + https://github.com/airsplay/lxmert + BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, *inputs, **kwargs) -> None: + super(BertPreTrainedModel, self).__init__() + + def init_bert_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, torch.nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained( + cls, + num_language_layers, + num_vision_layers, + num_mixed_layers, + bert_config, + state_dict=None, + cache_dir=None, + from_tf=False, + *inputs, + **kwargs, + ): + archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz" + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + tempdir = None + if os.path.isdir(resolved_archive_file) or from_tf: + serialization_dir = resolved_archive_file + else: + tempdir = tempfile.mkdtemp() + with tarfile.open(resolved_archive_file, "r:gz") as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs) + if state_dict is None and not from_tf: + weights_path = os.path.join(serialization_dir, "pytorch_model.bin") + state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None) + if tempdir: + shutil.rmtree(tempdir) + if from_tf: + weights_path = os.path.join(serialization_dir, "model.ckpt") + return load_tf_weights_in_bert(model, weights_path) + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs + ) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + start_prefix = "" + if not hasattr(model, "bert") and any(s.startswith("bert.") for s in state_dict.keys()): + start_prefix = "bert." + load(model, prefix=start_prefix) + return model + + +class BertAttention(nn.Module): + """BERT attention layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__( + self, + config, + ) -> None: + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, context): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(context) + mixed_value_layer = self.value(context) + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.dropout(nn.Softmax(dim=-1)(attention_scores)) + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertOutput(nn.Module): + """BERT output layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, config) -> None: + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertMixedLayer(nn.Module): + """BERT cross attention layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__( + self, + config, + ) -> None: + super().__init__() + self.att = BertAttention(config) + self.output = BertOutput(config) + + def forward(self, x, y): + output = self.att(x, y) + return self.output(output, x) + + +class Pooler(nn.Module): + """BERT pooler layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__( + self, + hidden_size, + ) -> None: + super(Pooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class MultiModal(BertPreTrainedModel): + """ + Multimodal Transformers From Pretrained BERT Weights" + """ + + def __init__( + self, + num_language_layers: int, + num_vision_layers: int, + num_mixed_layers: int, + bert_config: dict, # type: ignore + ) -> None: + """ + Args: + num_language_layers: number of language transformer layers. + num_vision_layers: number of vision transformer layers. + bert_config: configuration for bert language transformer encoder. + + """ + super().__init__() + self.config = type("obj", (object,), bert_config) + self.embeddings = BertEmbeddings(self.config) + self.language_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_language_layers)]) + self.vision_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_vision_layers)]) + self.mixed_encoder = nn.ModuleList([BertMixedLayer(self.config) for _ in range(num_mixed_layers)]) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, vision_feats=None, attention_mask=None): + language_features = self.embeddings(input_ids, token_type_ids) + for layer in self.vision_encoder: + hidden_state_vision = layer(vision_feats, None)[0] + for layer in self.language_encoder: + hidden_state_language = layer(language_features, attention_mask)[0] + for layer in self.mixed_encoder: + hidden_state_mixed = layer(hidden_state_language, hidden_state_vision) + return hidden_state_mixed + + +class VLTransformers(torch.nn.Module): + """ + Vision Language Multimodal Transformers" + """ + + def __init__( + self, + in_channels: int, + img_size: Union[Sequence[int], int], # type: ignore + patch_size: Union[Sequence[int], int], # type: ignore + num_classes: int, + num_language_layers: int, + num_vision_layers: int, + num_mixed_layers: int, + drop_out: float = 0.0, + bert_config: dict = { + "attention_probs_dropout_prob": 0.1, + "classifier_dropout": None, + "gradient_checkpointing": False, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.10.2", + "type_vocab_size": 2, + "use_cache": True, + "vocab_size": 30522, + "chunk_size_feed_forward": 0, + "is_decoder": False, + "add_cross_attention": False, + }, + ) -> None: + """ + Args: + in_channels: dimension of input channels. + img_size: dimension of input image. + patch_size: dimension of patch size. + num_classes: number of classes if classification is used. + num_language_layers: number of language transformer layers. + num_vision_layers: number of vision transformer layers. + num_mixed_layers: number of mixed transformer layers. + drop_out: faction of the input units to drop. + bert_config: configuration for bert language transformer encoder. + Examples:: + # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers, + 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head + >>> net = VLTransformers(in_channels=3, img_size=(224, 224), num_classes=3, num_language_layers=2, + num_vision_layers=2, num_mixed_layers=2, drop_out=0.2) + """ + super(VLTransformers, self).__init__() + + if not (0 <= drop_out <= 1): + raise ValueError("dropout_rate should be in the range of 0 and 1.") + + if (img_size[0] % patch_size[0] != 0) or (img_size[1] % patch_size[1] != 0): # type: ignore + raise ValueError("img_size should be divisible by patch_size.") + + self.multimodal = MultiModal.from_pretrained( + num_language_layers=num_language_layers, + num_vision_layers=num_vision_layers, + num_mixed_layers=num_mixed_layers, + bert_config=bert_config, + ) + + self.embed_dim = 768 + self.patch_size = patch_size + self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) # type: ignore + self.vision_proj = nn.Conv2d( + in_channels=in_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + ) + self.norm_vision_pos = nn.LayerNorm(self.embed_dim) + self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, self.embed_dim)) + self.pooler = Pooler(hidden_size=self.embed_dim) + self.drop = torch.nn.Dropout(drop_out) + self.cls_head = torch.nn.Linear(self.embed_dim, num_classes) + + def forward(self, input_ids, token_type_ids=None, vision_feats=None): + attention_mask = torch.ones_like(input_ids).unsqueeze(1).unsqueeze(2) + attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) + attention_mask = (1.0 - attention_mask) * -10000.0 + vision_feats = self.vision_proj(vision_feats).flatten(2).transpose(1, 2) + vision_feats = self.norm_vision_pos(vision_feats) + vision_feats = vision_feats + self.pos_embed_vis + hidden_state_mixed = self.multimodal( + input_ids=input_ids, token_type_ids=token_type_ids, vision_feats=vision_feats, attention_mask=attention_mask + ) + pooled_features = self.pooler(hidden_state_mixed) + logits = self.cls_head(self.drop(pooled_features)) + return logits diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index add47e27ca..824b5b33d3 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -32,10 +32,18 @@ map_classes_to_indices, ) from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis -from monai.utils import convert_to_numpy, convert_to_tensor, ensure_tuple, look_up_option, min_version, optional_import +from monai.utils import ( + convert_data_type, + convert_to_numpy, + convert_to_tensor, + ensure_tuple, + get_equivalent_dtype, + look_up_option, + min_version, + optional_import, +) from monai.utils.enums import TransformBackends from monai.utils.misc import is_module_ver_at_least -from monai.utils.type_conversion import convert_data_type PILImageImage, has_pil = optional_import("PIL.Image", name="Image") pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray") @@ -334,15 +342,16 @@ class ToTensor(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, device: Optional[torch.device] = None) -> None: + def __init__(self, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super().__init__() + self.dtype = dtype self.device = device def __call__(self, img: NdarrayOrTensor) -> torch.Tensor: """ Apply the transform to `img` and make it contiguous. """ - return convert_to_tensor(img, wrap_sequence=True, device=self.device) # type: ignore + return convert_to_tensor(img, dtype=self.dtype, device=self.device, wrap_sequence=True) # type: ignore class EnsureType(Transform): @@ -354,19 +363,24 @@ class EnsureType(Transform): Args: data_type: target data type to convert, should be "tensor" or "numpy". + dtype: target data content type to convert, for example: np.float32, torch.float, etc. + device: for Tensor data type, specify the target device. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, data_type: str = "tensor") -> None: - data_type = data_type.lower() - if data_type not in ("tensor", "numpy"): - raise ValueError("`data type` must be 'tensor' or 'numpy'.") - - self.data_type = data_type + def __init__( + self, + data_type: str = "tensor", + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + device: Optional[torch.device] = None, + ) -> None: + self.data_type = look_up_option(data_type.lower(), {"tensor", "numpy"}) + self.dtype = dtype + self.device = device - def __call__(self, data: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, data: NdarrayOrTensor): """ Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. @@ -375,7 +389,12 @@ def __call__(self, data: NdarrayOrTensor) -> NdarrayOrTensor: if applicable. """ - return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data) # type: ignore + if self.data_type == "tensor": + dtype_ = get_equivalent_dtype(self.dtype, torch.Tensor) + return convert_to_tensor(data, dtype=dtype_, device=self.device) + else: + dtype_ = get_equivalent_dtype(self.dtype, np.ndarray) + return convert_to_numpy(data, dtype=dtype_) class ToNumpy(Transform): @@ -385,25 +404,36 @@ class ToNumpy(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, dtype: Optional[DtypeLike] = None) -> None: + super().__init__() + self.dtype = dtype + def __call__(self, img: NdarrayOrTensor) -> np.ndarray: """ Apply the transform to `img` and make it contiguous. """ - return convert_to_numpy(img) # type: ignore + return convert_to_numpy(img, dtype=self.dtype) # type: ignore class ToCupy(Transform): """ Converts the input data to CuPy array, can support list or tuple of numbers, NumPy and PyTorch Tensor. + + Args: + dtype: data type specifier. It is inferred from the input by default. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __init__(self, dtype=None) -> None: + super().__init__() + self.dtype = dtype + + def __call__(self, data: NdarrayOrTensor): """ - Apply the transform to `img` and make it contiguous. + Create a CuPy array from `data` and make it contiguous """ - return cp.ascontiguousarray(cp.asarray(img)) # type: ignore + return convert_to_cupy(data, self.dtype) class ToPIL(Transform): @@ -779,6 +809,9 @@ def __call__( output_shape: expected shape of output indices. if None, use `self.output_shape` instead. """ + label, *_ = convert_data_type(label, np.ndarray) # type: ignore + if image is not None: + image, *_ = convert_data_type(image, np.ndarray) # type: ignore if output_shape is None: output_shape = self.output_shape fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold) @@ -828,6 +861,10 @@ def __call__( output_shape: expected shape of output indices. if None, use `self.output_shape` instead. """ + label, *_ = convert_data_type(label, np.ndarray) # type: ignore + if image is not None: + image, *_ = convert_data_type(image, np.ndarray) # type: ignore + if output_shape is None: output_shape = self.output_shape indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) @@ -848,6 +885,7 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform): """ def __call__(self, img: np.ndarray) -> np.ndarray: + img, *_ = convert_data_type(img, np.ndarray) # type: ignore # if img has channel dim, squeeze it if img.ndim == 4 and img.shape[0] == 1: img = np.squeeze(img, axis=0) @@ -914,6 +952,9 @@ def __call__( if label.shape[0] != 1: raise ValueError("Only supports single channel labels!") + img, *_ = convert_data_type(img, np.ndarray) # type: ignore + label, *_ = convert_data_type(label, np.ndarray) # type: ignore + # Generate extreme points self.randomize(label[0, :]) @@ -950,6 +991,7 @@ def __call__(self, img: torch.Tensor): img: PyTorch Tensor data for the TorchVision transform. """ + img, *_ = convert_data_type(img, torch.Tensor) # type: ignore return self.trans(img) @@ -980,7 +1022,7 @@ def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeL self.dtype = dtype def __call__(self, img: np.ndarray): - img = np.asarray(img) + img, *_ = convert_data_type(img, np.ndarray) # type: ignore img_flat = img.flatten() try: out_flat = np.copy(img_flat).astype(self.dtype) @@ -1036,6 +1078,7 @@ def __call__( mask must have the same shape as input `img`. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore if meta_data is None: meta_data = {} diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index b51ff6a9c8..3636dbc6c0 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -16,6 +16,7 @@ "get_equivalent_dtype", "convert_data_type", "get_dtype", + "convert_to_cupy", "convert_to_numpy", "convert_to_tensor", "convert_to_dst_type", @@ -60,6 +61,8 @@ def get_equivalent_dtype(dtype, data_type): im = torch.tensor(1) dtype = get_equivalent_dtype(np.float32, type(im)) """ + if dtype is None: + return None if data_type is torch.Tensor: if type(dtype) is torch.dtype: return dtype @@ -83,7 +86,12 @@ def get_dtype(data: Any): return type(data) -def convert_to_tensor(data, wrap_sequence: bool = False, device: Optional[torch.device] = None): +def convert_to_tensor( + data, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + wrap_sequence: bool = False, +): """ Utility to convert the input data to a PyTorch Tensor. If passing a dictionary, list or tuple, recursively check every item and convert it to PyTorch Tensor. @@ -92,36 +100,41 @@ def convert_to_tensor(data, wrap_sequence: bool = False, device: Optional[torch. data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to Tensors, strings and objects keep the original. for dictionary, list or tuple, convert every item to a Tensor if applicable. - wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. - If `True`, then `[1, 2]` -> `tensor([1, 2])`. + dtype: target data type to when converting to Tensor. + device: target device to put the converted Tensor data. + wrap_sequence: if `False`, then lists will recursively call this function. + E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`. """ if isinstance(data, torch.Tensor): - return data.contiguous().to(device) + return data.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) # type: ignore if isinstance(data, np.ndarray): # skip array of string classes and object, refer to: # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/_utils/collate.py#L13 if re.search(r"[SaUO]", data.dtype.str) is None: # numpy array with 0 dims is also sequence iterable, # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims - return torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data), device=device) - elif has_cp and isinstance(data, cp_ndarray): - return torch.as_tensor(data, device=device) - elif isinstance(data, (float, int, bool)): - return torch.as_tensor(data, device=device) - elif isinstance(data, Sequence) and wrap_sequence: - return torch.as_tensor(data, device=device) + if data.ndim > 0: + data = np.ascontiguousarray(data) + return torch.as_tensor(data, dtype=dtype, device=device) # type: ignore + elif ( + has_cp + and isinstance(data, cp_ndarray) + or isinstance(data, (float, int, bool)) + or (isinstance(data, Sequence) and wrap_sequence) + ): + return torch.as_tensor(data, dtype=dtype, device=device) # type: ignore elif isinstance(data, list): - return [convert_to_tensor(i, device=device) for i in data] + return [convert_to_tensor(i, dtype=dtype, device=device) for i in data] elif isinstance(data, tuple): - return tuple(convert_to_tensor(i, device=device) for i in data) + return tuple(convert_to_tensor(i, dtype=dtype, device=device) for i in data) elif isinstance(data, dict): - return {k: convert_to_tensor(v, device=device) for k, v in data.items()} + return {k: convert_to_tensor(v, dtype=dtype, device=device) for k, v in data.items()} return data -def convert_to_numpy(data, wrap_sequence: bool = False): +def convert_to_numpy(data, dtype: Optional[DtypeLike] = None, wrap_sequence: bool = False): """ Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple, recursively check every item and convert it to numpy array. @@ -130,23 +143,22 @@ def convert_to_numpy(data, wrap_sequence: bool = False): data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to numpy arrays, strings and objects keep the original. for dictionary, list or tuple, convert every item to a numpy array if applicable. + dtype: target data type when converting to numpy array. wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`. """ if isinstance(data, torch.Tensor): - data = data.detach().cpu().numpy() + data = data.detach().to(dtype=get_equivalent_dtype(dtype, torch.Tensor), device="cpu").numpy() elif has_cp and isinstance(data, cp_ndarray): - data = cp.asnumpy(data) - elif isinstance(data, (float, int, bool)): - data = np.asarray(data) - elif isinstance(data, Sequence) and wrap_sequence: - return np.asarray(data) + data = cp.asnumpy(data).astype(dtype) + elif isinstance(data, (np.ndarray, float, int, bool)) or (isinstance(data, Sequence) and wrap_sequence): + data = np.asarray(data, dtype=dtype) elif isinstance(data, list): - return [convert_to_numpy(i) for i in data] + return [convert_to_numpy(i, dtype=dtype) for i in data] elif isinstance(data, tuple): - return tuple(convert_to_numpy(i) for i in data) + return tuple(convert_to_numpy(i, dtype=dtype) for i in data) elif isinstance(data, dict): - return {k: convert_to_numpy(v) for k, v in data.items()} + return {k: convert_to_numpy(v, dtype=dtype) for k, v in data.items()} if isinstance(data, np.ndarray) and data.ndim > 0: data = np.ascontiguousarray(data) @@ -154,6 +166,42 @@ def convert_to_numpy(data, wrap_sequence: bool = False): return data +def convert_to_cupy(data, dtype, wrap_sequence: bool = True): + """ + Utility to convert the input data to a cupy array. If passing a dictionary, list or tuple, + recursively check every item and convert it to cupy array. + + Args: + data: input data can be PyTorch Tensor, numpy array, cupy array, list, dictionary, int, float, bool, str, etc. + Tensor, numpy array, cupy array, float, int, bool are converted to cupy arrays + + for dictionary, list or tuple, convert every item to a numpy array if applicable. + dtype: target data type when converting to Cupy array. + wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`. + If `True`, then `[1, 2]` -> `array([1, 2])`. + """ + + # direct calls + if isinstance(data, (cp_ndarray, np.ndarray, torch.Tensor, float, int, bool)) or ( + isinstance(data, Sequence) and wrap_sequence + ): + data = cp.asarray(data, dtype) + elif isinstance(data, list): + return [convert_to_cupy(i, dtype) for i in data] + elif isinstance(data, tuple): + return tuple(convert_to_cupy(i, dtype) for i in data) + elif isinstance(data, dict): + return {k: convert_to_cupy(v, dtype) for k, v in data.items()} + # make it contiguous + if isinstance(data, cp.ndarray): + if data.ndim > 0: + data = cp.ascontiguousarray(data) + else: + raise ValueError(f"The input data type [{type(data)}] cannot be converted into cupy arrays!") + + return data + + def convert_data_type( data: Any, output_type: Optional[type] = None, @@ -178,6 +226,8 @@ def convert_data_type( orig_type = torch.Tensor elif isinstance(data, np.ndarray): orig_type = np.ndarray + elif has_cp and isinstance(data, cp.ndarray): + orig_type = cp.ndarray else: orig_type = type(data) @@ -185,30 +235,27 @@ def convert_data_type( output_type = output_type or orig_type - dtype = get_equivalent_dtype(dtype or get_dtype(data), output_type) + dtype_ = get_equivalent_dtype(dtype or get_dtype(data), output_type) if output_type is torch.Tensor: - if orig_type is not torch.Tensor: - data = convert_to_tensor(data) - if dtype != data.dtype: - data = data.to(dtype) - if device is not None: - data = data.to(device) + data = convert_to_tensor(data, dtype=dtype_, device=device) elif output_type is np.ndarray: - if orig_type is not np.ndarray: - data = convert_to_numpy(data) - if data is not None and dtype != data.dtype: - data = data.astype(dtype) + data = convert_to_numpy(data, dtype=dtype_) + elif has_cp and output_type is cp.ndarray: + data = convert_to_cupy(data, dtype=dtype_) else: raise ValueError(f"Unsupported output type: {output_type}") return data, orig_type, orig_device -def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]: +def convert_to_dst_type( + src: Any, dst: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch.dtype]] = None +) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]: """ - If `dst` is `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`, - if `dst` is `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`, + If `dst` is an instance of `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`, + if `dst` is an instance of `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`, otherwise, convert to the type of `dst` directly. + `dtype` is an optional argument if the target `dtype` is different from the original `dst`'s data type. See Also: :func:`convert_data_type` @@ -217,6 +264,9 @@ def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor if isinstance(dst, torch.Tensor): device = dst.device + if dtype is None: + dtype = dst.dtype + output_type: Any if isinstance(dst, torch.Tensor): output_type = torch.Tensor @@ -224,4 +274,4 @@ def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor output_type = np.ndarray else: output_type = type(dst) - return convert_data_type(data=src, output_type=output_type, device=device, dtype=dst.dtype) + return convert_data_type(data=src, output_type=output_type, device=device, dtype=dtype) diff --git a/requirements-dev.txt b/requirements-dev.txt index 785454ad5d..ed8739ded8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -36,3 +36,4 @@ openslide-python==1.1.2 pandas requests einops +transformers diff --git a/setup.cfg b/setup.cfg index 6efe768a6f..f7ed90a14a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,7 @@ all = openslide-python==1.1.2 pandas einops + transformers nibabel = nibabel skimage = @@ -74,6 +75,8 @@ pandas = pandas einops = einops +transformers = + transformers [flake8] select = B,C,E,F,N,P,T4,W,B9 max_line_length = 120 diff --git a/tests/min_tests.py b/tests/min_tests.py index 5b376d7b57..bac6521889 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -140,6 +140,7 @@ def run_testsuit(): "test_zoom", "test_zoom_affine", "test_zoomd", + "test_vltransformer", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_to_tensor.py b/tests/test_to_tensor.py index 3d187a1dba..b065595e89 100644 --- a/tests/test_to_tensor.py +++ b/tests/test_to_tensor.py @@ -11,8 +11,8 @@ import unittest +import torch from parameterized import parameterized -from torch import Tensor from monai.transforms import ToTensor from tests.utils import TEST_NDARRAYS, assert_allclose, optional_import @@ -35,16 +35,16 @@ class TestToTensor(unittest.TestCase): @parameterized.expand(TESTS) def test_array_input(self, test_data, expected_shape): - result = ToTensor()(test_data) - self.assertTrue(isinstance(result, Tensor)) - assert_allclose(result, test_data) + result = ToTensor(dtype=torch.float32, device="cpu")(test_data) + self.assertTrue(isinstance(result, torch.Tensor)) + assert_allclose(result, test_data, type_test=False) self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand(TESTS_SINGLE) def test_single_input(self, test_data): result = ToTensor()(test_data) - self.assertTrue(isinstance(result, Tensor)) - assert_allclose(result, test_data) + self.assertTrue(isinstance(result, torch.Tensor)) + assert_allclose(result, test_data, type_test=False) self.assertEqual(result.ndim, 0) @unittest.skipUnless(has_cp, "CuPy is required.") @@ -52,8 +52,8 @@ def test_cupy(self): test_data = [[1, 2], [3, 4]] cupy_array = cp.ascontiguousarray(cp.asarray(test_data)) result = ToTensor()(cupy_array) - self.assertTrue(isinstance(result, Tensor)) - assert_allclose(result, test_data) + self.assertTrue(isinstance(result, torch.Tensor)) + assert_allclose(result, test_data, type_test=False) if __name__ == "__main__": diff --git a/tests/test_vltransformer.py b/tests/test_vltransformer.py new file mode 100644 index 0000000000..a92a9bf79a --- /dev/null +++ b/tests/test_vltransformer.py @@ -0,0 +1,80 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.vltransformer import VLTransformers + +TEST_CASE_VLTransformers = [] +for drop_out in [0.4]: + for in_channels in [3]: + for img_size in [224]: + for patch_size in [16, 32]: + for num_language_layers in [2]: + for num_vision_layers in [4]: + for num_mixed_layers in [3]: + for num_classes in [8]: + test_case = [ + { + "in_channels": in_channels, + "img_size": (img_size,) * 2, + "patch_size": (patch_size,) * 2, + "num_vision_layers": num_vision_layers, + "num_mixed_layers": num_mixed_layers, + "num_language_layers": num_language_layers, + "num_classes": num_classes, + "drop_out": drop_out, + }, + (2, num_classes), # type: ignore + ] + TEST_CASE_VLTransformers.append(test_case) + + +class TestPatchEmbeddingBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_VLTransformers) + def test_shape(self, input_param, expected_shape): + net = VLTransformers(**input_param) + with eval_mode(net): + result = net(torch.randint(2, (2, 512)), torch.randint(2, (2, 512)), torch.randn((2, 3, 224, 224))) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + VLTransformers( + in_channels=3, + img_size=(128, 128), + patch_size=(16, 16), + num_language_layers=2, + num_mixed_layers=4, + num_vision_layers=2, + num_classes=2, + drop_out=5.0, + ) + + with self.assertRaises(ValueError): + VLTransformers( + in_channels=1, + img_size=(97, 97), + patch_size=(16, 16), + num_language_layers=6, + num_mixed_layers=6, + num_vision_layers=8, + num_classes=8, + drop_out=0.4, + ) + + +if __name__ == "__main__": + unittest.main() From bd3da3943a8e1b9917ee1e772c0389b513b20b50 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Wed, 15 Sep 2021 16:03:24 -0700 Subject: [PATCH 02/11] add multimodal transformers Signed-off-by: ahatamizadeh --- tests/vltransformer.py | 355 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 355 insertions(+) create mode 100644 tests/vltransformer.py diff --git a/tests/vltransformer.py b/tests/vltransformer.py new file mode 100644 index 0000000000..af095a181c --- /dev/null +++ b/tests/vltransformer.py @@ -0,0 +1,355 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import shutil +import tarfile +import tempfile +from typing import Sequence, Union + +import torch +from torch import nn + +from monai.utils import optional_import + +transformers = optional_import("transformers") +load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert") +cached_path = optional_import("transformers.file_utils", name="cached_path")[0] +BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0] +BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0] + + +class BertPreTrainedModel(nn.Module): + """Module to load BERT pre-trained weights. + Based on: + LXMERT + https://github.com/airsplay/lxmert + BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, *inputs, **kwargs) -> None: + super(BertPreTrainedModel, self).__init__() + + def init_bert_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, torch.nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained( + cls, + num_language_layers, + num_vision_layers, + num_mixed_layers, + bert_config, + state_dict=None, + cache_dir=None, + from_tf=False, + *inputs, + **kwargs, + ): + archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz" + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + tempdir = None + if os.path.isdir(resolved_archive_file) or from_tf: + serialization_dir = resolved_archive_file + else: + tempdir = tempfile.mkdtemp() + with tarfile.open(resolved_archive_file, "r:gz") as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs) + if state_dict is None and not from_tf: + weights_path = os.path.join(serialization_dir, "pytorch_model.bin") + state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None) + if tempdir: + shutil.rmtree(tempdir) + if from_tf: + weights_path = os.path.join(serialization_dir, "model.ckpt") + return load_tf_weights_in_bert(model, weights_path) + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs + ) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + start_prefix = "" + if not hasattr(model, "bert") and any(s.startswith("bert.") for s in state_dict.keys()): + start_prefix = "bert." + load(model, prefix=start_prefix) + return model + + +class BertAttention(nn.Module): + """BERT attention layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__( + self, + config, + ) -> None: + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, context): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(context) + mixed_value_layer = self.value(context) + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.dropout(nn.Softmax(dim=-1)(attention_scores)) + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertOutput(nn.Module): + """BERT output layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, config) -> None: + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertMixedLayer(nn.Module): + """BERT cross attention layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__( + self, + config, + ) -> None: + super().__init__() + self.att = BertAttention(config) + self.output = BertOutput(config) + + def forward(self, x, y): + output = self.att(x, y) + return self.output(output, x) + + +class Pooler(nn.Module): + """BERT pooler layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__( + self, + hidden_size, + ) -> None: + super(Pooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class MultiModal(BertPreTrainedModel): + """ + Multimodal Transformers From Pretrained BERT Weights" + """ + + def __init__( + self, + num_language_layers: int, + num_vision_layers: int, + num_mixed_layers: int, + bert_config: dict, # type: ignore + ) -> None: + """ + Args: + num_language_layers: number of language transformer layers. + num_vision_layers: number of vision transformer layers. + bert_config: configuration for bert language transformer encoder. + + """ + super().__init__() + self.config = type("obj", (object,), bert_config) + self.embeddings = BertEmbeddings(self.config) + self.language_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_language_layers)]) + self.vision_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_vision_layers)]) + self.mixed_encoder = nn.ModuleList([BertMixedLayer(self.config) for _ in range(num_mixed_layers)]) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, vision_feats=None, attention_mask=None): + language_features = self.embeddings(input_ids, token_type_ids) + for layer in self.vision_encoder: + hidden_state_vision = layer(vision_feats, None)[0] + for layer in self.language_encoder: + hidden_state_language = layer(language_features, attention_mask)[0] + for layer in self.mixed_encoder: + hidden_state_mixed = layer(hidden_state_language, hidden_state_vision) + return hidden_state_mixed + + +class VLTransformers(torch.nn.Module): + """ + Vision Language Multimodal Transformers" + """ + + def __init__( + self, + in_channels: int, + img_size: Union[Sequence[int], int], # type: ignore + patch_size: Union[Sequence[int], int], # type: ignore + num_classes: int, + num_language_layers: int, + num_vision_layers: int, + num_mixed_layers: int, + drop_out: float = 0.0, + bert_config: dict = { + "attention_probs_dropout_prob": 0.1, + "classifier_dropout": None, + "gradient_checkpointing": False, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.10.2", + "type_vocab_size": 2, + "use_cache": True, + "vocab_size": 30522, + "chunk_size_feed_forward": 0, + "is_decoder": False, + "add_cross_attention": False, + }, + ) -> None: + """ + Args: + in_channels: dimension of input channels. + img_size: dimension of input image. + patch_size: dimension of patch size. + num_classes: number of classes if classification is used. + num_language_layers: number of language transformer layers. + num_vision_layers: number of vision transformer layers. + num_mixed_layers: number of mixed transformer layers. + drop_out: faction of the input units to drop. + bert_config: configuration for bert language transformer encoder. + Examples:: + # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers, + 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head + >>> net = VLTransformers(in_channels=3, img_size=(224, 224), num_classes=3, num_language_layers=2, + num_vision_layers=2, num_mixed_layers=2, drop_out=0.2) + """ + super(VLTransformers, self).__init__() + + if not (0 <= drop_out <= 1): + raise ValueError("dropout_rate should be in the range of 0 and 1.") + + if (img_size[0] % patch_size[0] != 0) or (img_size[1] % patch_size[1] != 0): # type: ignore + raise ValueError("img_size should be divisible by patch_size.") + + self.multimodal = MultiModal.from_pretrained( + num_language_layers=num_language_layers, + num_vision_layers=num_vision_layers, + num_mixed_layers=num_mixed_layers, + bert_config=bert_config, + ) + + self.embed_dim = 768 + self.patch_size = patch_size + self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) # type: ignore + self.vision_proj = nn.Conv2d( + in_channels=in_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + ) + self.norm_vision_pos = nn.LayerNorm(self.embed_dim) + self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, self.embed_dim)) + self.pooler = Pooler(hidden_size=self.embed_dim) + self.drop = torch.nn.Dropout(drop_out) + self.cls_head = torch.nn.Linear(self.embed_dim, num_classes) + + def forward(self, input_ids, token_type_ids=None, vision_feats=None): + attention_mask = torch.ones_like(input_ids).unsqueeze(1).unsqueeze(2) + attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) + attention_mask = (1.0 - attention_mask) * -10000.0 + vision_feats = self.vision_proj(vision_feats).flatten(2).transpose(1, 2) + vision_feats = self.norm_vision_pos(vision_feats) + vision_feats = vision_feats + self.pos_embed_vis + hidden_state_mixed = self.multimodal( + input_ids=input_ids, token_type_ids=token_type_ids, vision_feats=vision_feats, attention_mask=attention_mask + ) + pooled_features = self.pooler(hidden_state_mixed) + logits = self.cls_head(self.drop(pooled_features)) + return logits From f5c1406bc7dd211074c5079be15e7d35f3a5a25d Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Wed, 15 Sep 2021 16:32:56 -0700 Subject: [PATCH 03/11] add multimodal transformers Signed-off-by: ahatamizadeh --- monai/transforms/utility/array.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 824b5b33d3..9109fb04c5 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -34,6 +34,7 @@ from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis from monai.utils import ( convert_data_type, + convert_to_cupy, convert_to_numpy, convert_to_tensor, ensure_tuple, From 09e693a4d5715cc3985e47c9991d49ef8ee79e08 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Wed, 15 Sep 2021 19:35:58 -0700 Subject: [PATCH 04/11] add multimodal transformers Signed-off-by: ahatamizadeh --- docs/requirements.txt | 1 - docs/source/installation.md | 4 +- monai/_extensions/loader.py | 2 +- monai/_version.py | 6 +- monai/config/deviceconfig.py | 1 - monai/data/grid_dataset.py | 2 +- monai/data/utils.py | 4 +- monai/networks/blocks/dynunet_block.py | 21 +- monai/networks/blocks/dynunet_block_v1.py | 10 + monai/networks/layers/spatial_transforms.py | 9 +- monai/networks/nets/autoencoder.py | 4 + monai/networks/nets/dynunet.py | 15 +- monai/networks/nets/dynunet_v1.py | 3 + monai/networks/nets/resnet.py | 39 +- monai/networks/nets/varautoencoder.py | 17 + monai/transforms/__init__.py | 2 +- monai/transforms/croppad/array.py | 20 + monai/transforms/croppad/dictionary.py | 5 +- monai/transforms/intensity/array.py | 67 +++- monai/transforms/intensity/dictionary.py | 21 +- monai/transforms/inverse_batch_transform.py | 2 +- monai/transforms/spatial/array.py | 223 ++++++----- monai/transforms/spatial/dictionary.py | 116 +++--- monai/transforms/utility/dictionary.py | 84 +++- .../utils_pytorch_numpy_unification.py | 69 ++++ monai/utils/__init__.py | 1 + monai/utils/aliases.py | 4 +- requirements-dev.txt | 1 - setup.cfg | 3 - tests/min_tests.py | 1 - tests/test_affine.py | 179 ++++++--- tests/test_affine_grid.py | 160 ++++---- tests/test_affined.py | 193 ++++++---- tests/test_as_channel_first.py | 2 +- tests/test_delete_itemsd.py | 25 +- tests/test_dynunet.py | 16 +- tests/test_ensure_type.py | 10 +- tests/test_ensure_typed.py | 15 +- tests/test_flip.py | 6 +- tests/test_flipd.py | 6 +- tests/test_inverse_collation.py | 7 +- tests/test_label_to_mask.py | 2 +- tests/test_label_to_maskd.py | 2 +- tests/test_normalize_intensity.py | 42 +- tests/test_normalize_intensityd.py | 10 +- tests/test_rand_affine.py | 222 ++++++----- tests/test_rand_affine_grid.py | 322 ++++++++-------- tests/test_rand_affined.py | 360 +++++++++--------- tests/test_rand_axis_flip.py | 6 +- tests/test_rand_axis_flipd.py | 6 +- tests/test_rand_elastic_2d.py | 151 ++++---- tests/test_rand_elastic_3d.py | 130 ++++--- tests/test_rand_elasticd_2d.py | 248 ++++++------ tests/test_rand_elasticd_3d.py | 212 ++++++----- tests/test_rand_flip.py | 6 +- tests/test_rand_flipd.py | 6 +- tests/test_rand_rotate.py | 91 +++-- tests/test_rand_rotate90.py | 24 +- tests/test_rand_rotate90d.py | 24 +- tests/test_rand_rotated.py | 164 ++++---- tests/test_rand_scale_intensity.py | 2 +- tests/test_rand_scale_intensityd.py | 4 +- tests/test_rand_shift_intensityd.py | 4 +- tests/test_rand_zoom.py | 10 +- tests/test_rand_zoomd.py | 10 +- tests/test_resampler.py | 181 ++++++--- tests/test_resnet.py | 14 +- tests/test_rotate.py | 77 ++-- tests/test_rotate90.py | 24 +- tests/test_rotate90d.py | 24 +- tests/test_rotated.py | 50 ++- tests/test_scale_intensity.py | 4 +- tests/test_scale_intensity_range.py | 13 +- .../test_scale_intensity_range_percentiles.py | 10 +- tests/test_scale_intensity_ranged.py | 13 +- tests/test_scale_intensityd.py | 8 +- tests/test_shift_intensityd.py | 2 +- tests/test_threshold_intensity.py | 19 +- tests/test_threshold_intensityd.py | 52 ++- tests/test_to_cupy.py | 54 ++- tests/test_to_numpy.py | 17 +- tests/test_to_numpyd.py | 8 +- tests/test_to_pil.py | 2 +- tests/test_to_pild.py | 6 +- tests/test_transpose.py | 2 +- tests/test_transposed.py | 6 +- tests/test_utils_pytorch_numpy_unification.py | 46 +++ tests/test_zoom.py | 2 +- tests/test_zoomd.py | 16 +- tests/utils.py | 37 +- tests/vltransformer.py | 355 ----------------- 91 files changed, 2421 insertions(+), 2055 deletions(-) create mode 100644 tests/test_utils_pytorch_numpy_unification.py delete mode 100644 tests/vltransformer.py diff --git a/docs/requirements.txt b/docs/requirements.txt index 3530d63c49..00dd4d2c1e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -20,4 +20,3 @@ sphinxcontrib-serializinghtml sphinx-autodoc-typehints==1.11.1 pandas einops -transformers==4.10.2 diff --git a/docs/source/installation.md b/docs/source/installation.md index 902f596dfc..08ab109142 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -174,9 +174,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops] ``` which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas` , `einops` and `transformers`, respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas` and `einops`, respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/monai/_extensions/loader.py b/monai/_extensions/loader.py index 5f77480ecc..6c68fe08c7 100644 --- a/monai/_extensions/loader.py +++ b/monai/_extensions/loader.py @@ -34,7 +34,7 @@ def timeout(time, message): except KeyboardInterrupt as e: if timer is not None and timer.is_alive(): raise e # interrupt from user? - raise TimeoutError(message) + raise TimeoutError(message) from e finally: if timer is not None: try: diff --git a/monai/_version.py b/monai/_version.py index 79f569dd79..fb3a60690e 100644 --- a/monai/_version.py +++ b/monai/_version.py @@ -23,9 +23,9 @@ def get_keywords(): # setup.py/versioneer.py will grep for the variable names, so they must # each be defined on a line of their own. _version.py will just call # get_keywords(). - git_refnames = "$Format:%d$" - git_full = "$Format:%H$" - git_date = "$Format:%ci$" + git_refnames = " (HEAD -> dev, tag: 0.7.0rc1, releasing/0.7.0)" + git_full = "0f17aa991592fc6e635e86da3061b5dd3d669597" + git_date = "2021-09-15 21:03:08 +0000" keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} return keywords diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index ff45b29531..273431fc72 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -73,7 +73,6 @@ def get_optional_config_values(): output["psutil"] = psutil_version output["pandas"] = get_package_version("pandas") output["einops"] = get_package_version("einops") - output["transformers"] = get_package_version("transformers") return output diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 5b2a4d7abd..5c330f10e4 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -141,7 +141,7 @@ def __iter__(self): try: iter_end = len(self.dataset) # TODO: support iterable self.dataset except TypeError: - raise NotImplementedError("image dataset must implement `len()`.") + raise NotImplementedError("image dataset must implement `len()`.") from None if worker_info is not None: # split workload diff --git a/monai/data/utils.py b/monai/data/utils.py index aab23217dc..a5cb5057d4 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -283,7 +283,7 @@ def list_data_collate(batch: Sequence): + "`DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its " + "documentation)." ) - raise RuntimeError(re_str) + raise RuntimeError(re_str) from re except TypeError as re: re_str = str(re) if "numpy" in re_str and "Tensor" in re_str: @@ -294,7 +294,7 @@ def list_data_collate(batch: Sequence): + "creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem " + "(check its documentation)." ) - raise TypeError(re_str) + raise TypeError(re_str) from re def decollate_batch(batch, detach: bool = True): diff --git a/monai/networks/blocks/dynunet_block.py b/monai/networks/blocks/dynunet_block.py index bb654d841c..fc37fc8999 100644 --- a/monai/networks/blocks/dynunet_block.py +++ b/monai/networks/blocks/dynunet_block.py @@ -33,6 +33,7 @@ class UnetResBlock(nn.Module): kernel_size: convolution kernel size. stride: convolution stride. norm_name: feature normalization type and arguments. + dropout: dropout probability """ @@ -44,6 +45,7 @@ def __init__( kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], norm_name: Union[Tuple, str], + dropout: Optional[Union[Tuple, str, float]] = None, ): super(UnetResBlock, self).__init__() self.conv1 = get_conv_layer( @@ -52,6 +54,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=stride, + dropout=dropout, conv_only=True, ) self.conv2 = get_conv_layer( @@ -60,6 +63,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=1, + dropout=dropout, conv_only=True, ) self.conv3 = get_conv_layer( @@ -68,6 +72,7 @@ def __init__( out_channels, kernel_size=1, stride=stride, + dropout=dropout, conv_only=True, ) self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) @@ -107,6 +112,7 @@ class UnetBasicBlock(nn.Module): kernel_size: convolution kernel size. stride: convolution stride. norm_name: feature normalization type and arguments. + dropout: dropout probability """ @@ -118,6 +124,7 @@ def __init__( kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], norm_name: Union[Tuple, str], + dropout: Optional[Union[Tuple, str, float]] = None, ): super(UnetBasicBlock, self).__init__() self.conv1 = get_conv_layer( @@ -126,6 +133,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=stride, + dropout=dropout, conv_only=True, ) self.conv2 = get_conv_layer( @@ -134,6 +142,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=1, + dropout=dropout, conv_only=True, ) self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) @@ -164,6 +173,7 @@ class UnetUpBlock(nn.Module): stride: convolution stride. upsample_kernel_size: convolution kernel size for transposed convolution layers. norm_name: feature normalization type and arguments. + dropout: dropout probability """ @@ -176,6 +186,7 @@ def __init__( stride: Union[Sequence[int], int], upsample_kernel_size: Union[Sequence[int], int], norm_name: Union[Tuple, str], + dropout: Optional[Union[Tuple, str, float]] = None, ): super(UnetUpBlock, self).__init__() upsample_stride = upsample_kernel_size @@ -185,6 +196,7 @@ def __init__( out_channels, kernel_size=upsample_kernel_size, stride=upsample_stride, + dropout=dropout, conv_only=True, is_transposed=True, ) @@ -194,6 +206,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=1, + dropout=dropout, norm_name=norm_name, ) @@ -206,10 +219,12 @@ def forward(self, inp, skip): class UnetOutBlock(nn.Module): - def __init__(self, spatial_dims: int, in_channels: int, out_channels: int): + def __init__( + self, spatial_dims: int, in_channels: int, out_channels: int, dropout: Optional[Union[Tuple, str, float]] = None + ): super(UnetOutBlock, self).__init__() self.conv = get_conv_layer( - spatial_dims, in_channels, out_channels, kernel_size=1, stride=1, bias=True, conv_only=True + spatial_dims, in_channels, out_channels, kernel_size=1, stride=1, dropout=dropout, bias=True, conv_only=True ) def forward(self, inp): @@ -224,6 +239,7 @@ def get_conv_layer( stride: Union[Sequence[int], int] = 1, act: Optional[Union[Tuple, str]] = Act.PRELU, norm: Union[Tuple, str] = Norm.INSTANCE, + dropout: Optional[Union[Tuple, str, float]] = None, bias: bool = False, conv_only: bool = True, is_transposed: bool = False, @@ -240,6 +256,7 @@ def get_conv_layer( kernel_size=kernel_size, act=act, norm=norm, + dropout=dropout, bias=bias, conv_only=conv_only, is_transposed=is_transposed, diff --git a/monai/networks/blocks/dynunet_block_v1.py b/monai/networks/blocks/dynunet_block_v1.py index d5d9bbf3dc..b5b88dd0df 100644 --- a/monai/networks/blocks/dynunet_block_v1.py +++ b/monai/networks/blocks/dynunet_block_v1.py @@ -32,6 +32,7 @@ def __init__( kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], norm_name: str, + dropout: float = 0.0, ): nn.Module.__init__(self) self.conv1 = get_conv_layer( @@ -40,6 +41,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=stride, + dropout=dropout, conv_only=True, ) self.conv2 = get_conv_layer( @@ -48,6 +50,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=1, + dropout=dropout, conv_only=True, ) self.conv3 = get_conv_layer( @@ -56,6 +59,7 @@ def __init__( out_channels, kernel_size=1, stride=stride, + dropout=dropout, conv_only=True, ) self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) @@ -81,6 +85,7 @@ def __init__( kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], norm_name: str, + dropout: float = 0.0, ): nn.Module.__init__(self) self.conv1 = get_conv_layer( @@ -89,6 +94,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=stride, + dropout=dropout, conv_only=True, ) self.conv2 = get_conv_layer( @@ -97,6 +103,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=1, + dropout=dropout, conv_only=True, ) self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) @@ -118,6 +125,7 @@ def __init__( stride: Union[Sequence[int], int], upsample_kernel_size: Union[Sequence[int], int], norm_name: str, + dropout: float = 0.0, ): nn.Module.__init__(self) upsample_stride = upsample_kernel_size @@ -127,6 +135,7 @@ def __init__( out_channels, kernel_size=upsample_kernel_size, stride=upsample_stride, + dropout=dropout, conv_only=True, is_transposed=True, ) @@ -137,6 +146,7 @@ def __init__( kernel_size=kernel_size, stride=1, norm_name=norm_name, + dropout=dropout, ) diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index 511c24fcb0..6b5acb166a 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -46,7 +46,9 @@ def backward(ctx, grad): return None, grads[0], None, None, None -def grid_pull(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", bound="zero", extrapolate: bool = True): +def grid_pull( + input: torch.Tensor, grid: torch.Tensor, interpolation="linear", bound="zero", extrapolate: bool = True +) -> torch.Tensor: """ Sample an image with respect to a deformation field. @@ -112,8 +114,9 @@ def grid_pull(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", b _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i) for i in ensure_tuple(interpolation) ] - - return _GridPull.apply(input, grid, interpolation, bound, extrapolate) + out: torch.Tensor + out = _GridPull.apply(input, grid, interpolation, bound, extrapolate) # type: ignore + return out class _GridPush(torch.autograd.Function): diff --git a/monai/networks/nets/autoencoder.py b/monai/networks/nets/autoencoder.py index 08b84d0566..ed5e351779 100644 --- a/monai/networks/nets/autoencoder.py +++ b/monai/networks/nets/autoencoder.py @@ -22,6 +22,10 @@ class AutoEncoder(nn.Module): + """ + Base class for the architecture implementing :py:class:`monai.networks.nets.VarAutoEncoder`. + """ + @deprecated_arg( name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." ) diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 4af70b22c7..d65cd9f5f4 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -86,6 +86,7 @@ class DynUNet(nn.Module): strides: convolution strides for each blocks. upsample_kernel_size: convolution kernel size for transposed convolution layers. The values should equal to strides[1:]. + dropout: dropout ratio. Defaults to no dropout. norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``. deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. If ``True``, in training mode, the forward function will output not only the last feature @@ -115,6 +116,7 @@ def __init__( kernel_size: Sequence[Union[Sequence[int], int]], strides: Sequence[Union[Sequence[int], int]], upsample_kernel_size: Sequence[Union[Sequence[int], int]], + dropout: Optional[Union[Tuple, str, float]] = None, norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), deep_supervision: bool = False, deep_supr_num: int = 1, @@ -128,6 +130,7 @@ def __init__( self.strides = strides self.upsample_kernel_size = upsample_kernel_size self.norm_name = norm_name + self.dropout = dropout self.conv_block = UnetResBlock if res_block else UnetBasicBlock self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] self.input_block = self.get_input_block() @@ -184,7 +187,7 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck): def check_kernel_stride(self): kernels, strides = self.kernel_size, self.strides error_msg = "length of kernel_size and strides should be the same, and no less than 3." - if not (len(kernels) == len(strides) and len(kernels) >= 3): + if len(kernels) != len(strides) or len(kernels) < 3: raise AssertionError(error_msg) for idx, k_i in enumerate(kernels): @@ -225,6 +228,7 @@ def get_input_block(self): self.kernel_size[0], self.strides[0], self.norm_name, + dropout=self.dropout, ) def get_bottleneck(self): @@ -235,14 +239,11 @@ def get_bottleneck(self): self.kernel_size[-1], self.strides[-1], self.norm_name, + dropout=self.dropout, ) def get_output_block(self, idx: int): - return UnetOutBlock( - self.spatial_dims, - self.filters[idx], - self.out_channels, - ) + return UnetOutBlock(self.spatial_dims, self.filters[idx], self.out_channels, dropout=self.dropout) def get_downsamples(self): inp, out = self.filters[:-2], self.filters[1:-1] @@ -276,6 +277,7 @@ def get_module_list( "kernel_size": kernel, "stride": stride, "norm_name": self.norm_name, + "dropout": self.dropout, "upsample_kernel_size": up_kernel, } layer = conv_block(**params) @@ -289,6 +291,7 @@ def get_module_list( "kernel_size": kernel, "stride": stride, "norm_name": self.norm_name, + "dropout": self.dropout, } layer = conv_block(**params) layers.append(layer) diff --git a/monai/networks/nets/dynunet_v1.py b/monai/networks/nets/dynunet_v1.py index feb05d1762..c6a54807e4 100644 --- a/monai/networks/nets/dynunet_v1.py +++ b/monai/networks/nets/dynunet_v1.py @@ -38,6 +38,7 @@ class DynUNetV1(DynUNet): kernel_size: convolution kernel size. strides: convolution strides for each blocks. upsample_kernel_size: convolution kernel size for transposed convolution layers. + dropout: dropout ratio. Defaults to no dropout. norm_name: [``"batch"``, ``"instance"``, ``"group"``]. Defaults to "instance". deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. deep_supr_num: number of feature maps that will output during deep supervision head. Defaults to 1. @@ -57,6 +58,7 @@ def __init__( kernel_size: Sequence[Union[Sequence[int], int]], strides: Sequence[Union[Sequence[int], int]], upsample_kernel_size: Sequence[Union[Sequence[int], int]], + dropout: float = 0.0, norm_name: str = "instance", deep_supervision: bool = False, deep_supr_num: int = 1, @@ -70,6 +72,7 @@ def __init__( self.strides = strides self.upsample_kernel_size = upsample_kernel_size self.norm_name = norm_name + self.dropout = dropout self.conv_block = _UnetResBlockV1 if res_block else _UnetBasicBlockV1 # type: ignore self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] self.input_block = self.get_input_block() diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index a5e6b7ab81..3b86dc3d62 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -14,9 +14,10 @@ import torch import torch.nn as nn -import torch.nn.functional as F from monai.networks.layers.factories import Conv, Norm, Pool +from monai.networks.layers.utils import get_pool_layer +from monai.utils.module import look_up_option __all__ = ["ResNet", "resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"] @@ -162,7 +163,9 @@ class ResNet(nn.Module): conv1_t_size: size of first convolution layer, determines kernel and padding. conv1_t_stride: stride of first convolution layer. no_max_pool: bool argument to determine if to use maxpool layer. - shortcut_type: which downsample block to use. + shortcut_type: which downsample block to use. Options are 'A', 'B', default to 'B'. + - 'A': using `self._downsample_basic_block`. + - 'B': kernel_size 1 conv + norm. widen_factor: widen output for each layer. num_classes: number of output (classifications) """ @@ -198,7 +201,7 @@ def __init__( ] block_avgpool = get_avgpool() - conv1_kernel, conv1_stride, con1_padding = get_conv1(conv1_t_size, conv1_t_stride) + conv1_kernel, conv1_stride, conv1_padding = get_conv1(conv1_t_size, conv1_t_stride) block_inplanes = [int(x * widen_factor) for x in block_inplanes] self.in_planes = block_inplanes[0] @@ -209,7 +212,7 @@ def __init__( self.in_planes, kernel_size=conv1_kernel[spatial_dims], stride=conv1_stride[spatial_dims], - padding=con1_padding[spatial_dims], + padding=conv1_padding[spatial_dims], bias=False, ) self.bn1 = norm_type(self.in_planes) @@ -234,14 +237,9 @@ def __init__( nn.init.constant_(torch.as_tensor(m.bias), 0) def _downsample_basic_block(self, x: torch.Tensor, planes: int, stride: int, spatial_dims: int = 3) -> torch.Tensor: - assert spatial_dims == 3 - out: torch.Tensor = F.avg_pool3d(x, kernel_size=1, stride=stride) - zero_pads = torch.zeros(out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4)) - if isinstance(out.data, torch.FloatTensor): - zero_pads = zero_pads.cuda() - + out: torch.Tensor = get_pool_layer(("avg", {"kernel_size": 1, "stride": stride}), spatial_dims=spatial_dims)(x) + zero_pads = torch.zeros(out.size(0), planes - out.size(1), *out.shape[2:], dtype=out.dtype, device=out.device) out = torch.cat([out.data, zero_pads], dim=1) - return out def _make_layer( @@ -259,9 +257,12 @@ def _make_layer( downsample: Union[nn.Module, partial, None] = None if stride != 1 or self.in_planes != planes * block.expansion: - if shortcut_type == "A": + if look_up_option(shortcut_type, {"A", "B"}) == "A": downsample = partial( - self._downsample_basic_block, planes=planes * block.expansion, kernel_size=1, stride=stride + self._downsample_basic_block, + planes=planes * block.expansion, + stride=stride, + spatial_dims=spatial_dims, ) else: downsample = nn.Sequential( @@ -269,12 +270,16 @@ def _make_layer( norm_type(planes * block.expansion), ) - layers = [] - layers.append( + layers = [ block( - in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample + in_planes=self.in_planes, + planes=planes, + spatial_dims=spatial_dims, + stride=stride, + downsample=downsample, ) - ) + ] + self.in_planes = planes * block.expansion for _i in range(1, blocks): layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims)) diff --git a/monai/networks/nets/varautoencoder.py b/monai/networks/nets/varautoencoder.py index 31e187106f..3baa59531a 100644 --- a/monai/networks/nets/varautoencoder.py +++ b/monai/networks/nets/varautoencoder.py @@ -28,6 +28,23 @@ class VarAutoEncoder(AutoEncoder): """ Variational Autoencoder based on the paper - https://arxiv.org/abs/1312.6114 + .. code-block:: python + + from monai.networks.nets import VarAutoEncoder + + model = VarAutoEncoder( + dimensions=2, + in_shape=(32, 32), # image spatial shape + out_channels=1, + latent_size=2, + channels=(16, 32, 64), + strides=(1, 2, 2), + ) + + see also: + - Variational autoencoder network with MedNIST Dataset + https://github.com/Project-MONAI/tutorials/blob/master/modules/varautoencoder_mednist.ipynb + .. deprecated:: 0.6.0 ``dimensions`` is deprecated, use ``spatial_dims`` instead. """ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index a07dee867b..8a32b9e0b8 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -524,4 +524,4 @@ weighted_patch_samples, zero_margins, ) -from .utils_pytorch_numpy_unification import in1d, moveaxis +from .utils_pytorch_numpy_unification import clip, in1d, moveaxis, percentile, where diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 7e3bc835dd..d3cec35d93 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -421,6 +421,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]): Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ + img, *_ = convert_data_type(img, np.ndarray) sd = min(len(self.slices), len(img.shape[1:])) # spatial dims slices = [slice(None)] + self.slices[:sd] return img[tuple(slices)] @@ -449,6 +450,7 @@ def __call__(self, img: np.ndarray): Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore roi_size = fall_back_tuple(self.roi_size, img.shape[1:]) center = [i // 2 for i in img.shape[1:]] cropper = SpatialCrop(roi_center=center, roi_size=roi_size) @@ -469,6 +471,7 @@ def __init__(self, roi_scale: Union[Sequence[float], float]): self.roi_scale = roi_scale def __call__(self, img: np.ndarray): + img, *_ = convert_data_type(img, np.ndarray) # type: ignore img_size = img.shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] @@ -530,6 +533,7 @@ def __call__(self, img: np.ndarray): Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore self.randomize(img.shape[1:]) if self._size is None: raise AssertionError @@ -576,6 +580,7 @@ def __call__(self, img: np.ndarray): Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore img_size = img.shape[1:] ndim = len(img_size) self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] @@ -645,6 +650,7 @@ def __call__(self, img: np.ndarray) -> List[np.ndarray]: Apply the transform to `img`, assuming `img` is channel-first and cropping doesn't change the channel dim. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore return [self.cropper(img) for _ in range(self.num_samples)] @@ -754,6 +760,8 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore + box_start, box_end = self.compute_bounding_box(img) cropped = self.crop_pad(img, box_start, box_end, mode) @@ -799,12 +807,16 @@ def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> Returns: A list of image patches """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore if weight_map is None: weight_map = self.weight_map if weight_map is None: raise ValueError("weight map must be provided for weighted patch sampling.") if img.shape[1:] != weight_map.shape[1:]: raise ValueError(f"image and weight map spatial shape mismatch: {img.shape[1:]} vs {weight_map.shape[1:]}.") + + weight_map, *_ = convert_data_type(weight_map, np.ndarray) # type: ignore + self.randomize(weight_map) _spatial_size = fall_back_tuple(self.spatial_size, weight_map.shape[1:]) results = [] @@ -940,6 +952,9 @@ def __call__( if image is None: image = self.image + image, *_ = convert_data_type(image, np.ndarray) # type: ignore + label, *_ = convert_data_type(label, np.ndarray) # type: ignore + self.randomize(label, fg_indices, bg_indices, image) results: List[np.ndarray] = [] if self.centers is not None: @@ -1073,6 +1088,9 @@ def __call__( if image is None: image = self.image + image, *_ = convert_data_type(image, np.ndarray) # type: ignore + label, *_ = convert_data_type(label, np.ndarray) # type: ignore + self.randomize(label, indices, image) results: List[np.ndarray] = [] if self.centers is not None: @@ -1125,6 +1143,7 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N If None, defaults to the ``mode`` in construction. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore return self.padder(self.cropper(img), mode=mode) @@ -1159,6 +1178,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: """ See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore bbox = [] for channel in range(img.shape[0]): diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 5c846b8d04..233f1b6edf 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -51,6 +51,7 @@ from monai.utils import ImageMetaKey as Key from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple from monai.utils.enums import InverseKeys +from monai.utils.type_conversion import convert_data_type __all__ = [ "PadModeSequence", @@ -848,7 +849,9 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key]) + img: np.ndarray + img, *_ = convert_data_type(d[self.source_key], np.ndarray) # type: ignore + box_start, box_end = self.cropper.compute_bounding_box(img=img) d[self.start_coord_key] = box_start d[self.end_coord_key] = box_end for key, m in self.key_iterator(d, self.mode): diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index f6d4dfff5a..a1423c8ee5 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -28,6 +28,7 @@ from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter from monai.transforms.transform import RandomizableTransform, Transform from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array +from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where from monai.utils import ( PT_BEFORE_1_7, InvalidPyTorchVersionError, @@ -379,7 +380,11 @@ class ScaleIntensity(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( - self, minv: Optional[float] = 0.0, maxv: Optional[float] = 1.0, factor: Optional[float] = None + self, + minv: Optional[float] = 0.0, + maxv: Optional[float] = 1.0, + factor: Optional[float] = None, + dtype: DtypeLike = np.float32, ) -> None: """ Args: @@ -387,10 +392,12 @@ def __init__( maxv: maximum value of output data. factor: factor scale by ``v = v * (1 + factor)``. In order to use this parameter, please set `minv` and `maxv` into None. + dtype: output data type, defaults to float32. """ self.minv = minv self.maxv = maxv self.factor = factor + self.dtype = dtype def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ @@ -401,10 +408,10 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ if self.minv is not None and self.maxv is not None: - return rescale_array(img, self.minv, self.maxv, img.dtype) + return rescale_array(img, self.minv, self.maxv, dtype=self.dtype) if self.factor is not None: out = img * (1 + self.factor) - out, *_ = convert_data_type(out, dtype=img.dtype) + out, *_ = convert_data_type(out, dtype=self.dtype) return out raise ValueError("Incompatible values: minv=None or maxv=None and factor=None.") @@ -417,12 +424,18 @@ class RandScaleIntensity(RandomizableTransform): backend = ScaleIntensity.backend - def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1) -> None: + def __init__( + self, + factors: Union[Tuple[float, float], float], + prob: float = 0.1, + dtype: DtypeLike = np.float32, + ) -> None: """ Args: factors: factor range to randomly scale by ``v = v * (1 + factor)``. if single number, factor value is picked from (-factors, factors). prob: probability of scale. + dtype: output data type, defaults to float32. """ RandomizableTransform.__init__(self, prob) @@ -433,6 +446,7 @@ def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1 else: self.factors = (min(factors), max(factors)) self.factor = self.factors[0] + self.dtype = dtype def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) @@ -445,7 +459,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: self.randomize() if not self._do_transform: return img - scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor) + scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype) return scaler(img) @@ -517,6 +531,7 @@ def __call__(self, img: np.ndarray): """ Apply the transform to `img`. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore self.randomize(data=img) if not self._do_transform: return img @@ -642,6 +657,8 @@ class ThresholdIntensity(Transform): cval: value to fill the remaining parts of the image, default is 0. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> None: if not isinstance(threshold, (int, float)): raise ValueError("threshold must be a float or int number.") @@ -649,13 +666,14 @@ def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> N self.above = above self.cval = cval - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - return np.asarray( - np.where(img > self.threshold if self.above else img < self.threshold, img, self.cval), dtype=img.dtype - ) + mask = img > self.threshold if self.above else img < self.threshold + res = where(mask, img, self.cval) + res, *_ = convert_data_type(res, dtype=img.dtype) + return res class ScaleIntensityRange(Transform): @@ -671,6 +689,8 @@ class ScaleIntensityRange(Transform): clip: whether to perform clip after scaling. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, a_min: float, a_max: float, b_min: float, b_max: float, clip: bool = False) -> None: self.a_min = a_min self.a_max = a_max @@ -678,7 +698,7 @@ def __init__(self, a_min: float, a_max: float, b_min: float, b_max: float, clip: self.b_max = b_max self.clip = clip - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ @@ -689,7 +709,7 @@ def __call__(self, img: np.ndarray): img = (img - self.a_min) / (self.a_max - self.a_min) img = img * (self.b_max - self.b_min) + self.b_min if self.clip: - img = np.asarray(np.clip(img, self.b_min, self.b_max)) + img = clip(img, self.b_min, self.b_max) return img @@ -712,6 +732,7 @@ def __call__(self, img: np.ndarray): """ Apply the transform to `img`. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore epsilon = 1e-7 img_min = img.min() img_range = img.max() - img_min @@ -754,6 +775,7 @@ def __call__(self, img: np.ndarray): """ Apply the transform to `img`. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore self.randomize() if self.gamma_value is None: raise ValueError("gamma_value is not set.") @@ -818,6 +840,8 @@ class ScaleIntensityRangePercentiles(Transform): relative: whether to scale to the corresponding percentiles of [b_min, b_max]. """ + backend = ScaleIntensityRange.backend + def __init__( self, lower: float, upper: float, b_min: float, b_max: float, clip: bool = False, relative: bool = False ) -> None: @@ -832,12 +856,12 @@ def __init__( self.clip = clip self.relative = relative - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - a_min = np.percentile(img, self.lower) - a_max = np.percentile(img, self.upper) + a_min: float = percentile(img, self.lower) # type: ignore + a_max: float = percentile(img, self.upper) # type: ignore b_min = self.b_min b_max = self.b_max @@ -849,7 +873,7 @@ def __call__(self, img: np.ndarray): img = scalar(img) if self.clip: - img = np.asarray(np.clip(img, self.b_min, self.b_max)) + img = clip(img, self.b_min, self.b_max) return img @@ -889,10 +913,13 @@ def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> n - ValueError: When ``mask_data`` and ``img`` channels differ and ``mask_data`` is not single channel. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore mask_data = self.mask_data if mask_data is None else mask_data if mask_data is None: raise ValueError("must provide the mask_data when initializing the transform or at runtime.") + mask_data, *_ = convert_data_type(mask_data, np.ndarray) # type: ignore + mask_data = np.asarray(self.select_fn(mask_data)) if mask_data.shape[0] != 1 and mask_data.shape[0] != img.shape[0]: raise ValueError( @@ -915,7 +942,7 @@ class SavitzkyGolaySmooth(Transform): or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information. """ - backend = [TransformBackends.NUMPY] + backend = [TransformBackends.TORCH] def __init__(self, window_length: int, order: int, axis: int = 1, mode: str = "zeros"): @@ -979,6 +1006,7 @@ def __call__(self, img: np.ndarray): np.ndarray containing envelope of data in img along the specified axis. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore # add one to transform axis because a batch axis will be added at dimension 0 hilbert_transform = HilbertTransform(self.axis + 1, self.n) # convert to Tensor and add Batch axis expected by HilbertTransform @@ -1005,6 +1033,7 @@ def __init__(self, sigma: Union[Sequence[float], float] = 1.0, approx: str = "er self.approx = approx def __call__(self, img: np.ndarray): + img, *_ = convert_data_type(img, np.ndarray) # type: ignore gaussian_filter = GaussianFilter(img.ndim - 1, self.sigma, approx=self.approx) input_data = torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0) return gaussian_filter(input_data).squeeze(0).detach().numpy() @@ -1049,6 +1078,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1]) def __call__(self, img: np.ndarray): + img, *_ = convert_data_type(img, np.ndarray) # type: ignore self.randomize() if not self._do_transform: return img @@ -1096,6 +1126,7 @@ def __init__( self.approx = approx def __call__(self, img: np.ndarray): + img, *_ = convert_data_type(img, np.ndarray) # type: ignore gaussian_filter1 = GaussianFilter(img.ndim - 1, self.sigma1, approx=self.approx) gaussian_filter2 = GaussianFilter(img.ndim - 1, self.sigma2, approx=self.approx) input_data = torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0) @@ -1162,6 +1193,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.a = self.R.uniform(low=self.alpha[0], high=self.alpha[1]) def __call__(self, img: np.ndarray): + img, *_ = convert_data_type(img, np.ndarray) # type: ignore self.randomize() if not self._do_transform: return img @@ -1206,6 +1238,7 @@ def randomize(self, data: Optional[Any] = None) -> None: ) def __call__(self, img: np.ndarray) -> np.ndarray: + img, *_ = convert_data_type(img, np.ndarray) # type: ignore self.randomize() if not self._do_transform: return img @@ -1692,6 +1725,7 @@ def _transform_holes(self, img: np.ndarray) -> np.ndarray: raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") def __call__(self, img: np.ndarray): + img, *_ = convert_data_type(img, np.ndarray) # type: ignore self.randomize(img.shape[1:]) if self._do_transform: img = self._transform_holes(img=img) @@ -1850,6 +1884,7 @@ def __init__( self.dtype = dtype def __call__(self, img: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray: + img, *_ = convert_data_type(img, np.ndarray) # type: ignore return equalize_hist( img=img, mask=mask if mask is not None else self.mask, diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index ca24980359..07a6045870 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -488,6 +488,7 @@ def __init__( minv: Optional[float] = 0.0, maxv: Optional[float] = 1.0, factor: Optional[float] = None, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: """ @@ -498,11 +499,12 @@ def __init__( maxv: maximum value of output data. factor: factor scale by ``v = v * (1 + factor)``. In order to use this parameter, please set `minv` and `maxv` into None. + dtype: output data type, defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.scaler = ScaleIntensity(minv, maxv, factor) + self.scaler = ScaleIntensity(minv, maxv, factor, dtype) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -523,6 +525,7 @@ def __init__( keys: KeysCollection, factors: Union[Tuple[float, float], float], prob: float = 0.1, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: """ @@ -533,6 +536,7 @@ def __init__( if single number, factor value is picked from (-factors, factors). prob: probability of rotating. (Default 0.1, with 10% probability it returns a rotated array.) + dtype: output data type, defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ @@ -546,6 +550,7 @@ def __init__( else: self.factors = (min(factors), max(factors)) self.factor = self.factors[0] + self.dtype = dtype def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) @@ -556,7 +561,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N self.randomize() if not self._do_transform: return d - scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor) + scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype) for key in self.key_iterator(d): d[key] = scaler(d[key]) return d @@ -659,6 +664,8 @@ class ThresholdIntensityd(MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = ThresholdIntensity.backend + def __init__( self, keys: KeysCollection, @@ -670,7 +677,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.filter = ThresholdIntensity(threshold, above, cval) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.filter(d[key]) @@ -692,6 +699,8 @@ class ScaleIntensityRanged(MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = ScaleIntensityRange.backend + def __init__( self, keys: KeysCollection, @@ -705,7 +714,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.scaler(d[key]) @@ -809,6 +818,8 @@ class ScaleIntensityRangePercentilesd(MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = ScaleIntensityRangePercentiles.backend + def __init__( self, keys: KeysCollection, @@ -823,7 +834,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.scaler(d[key]) diff --git a/monai/transforms/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py index d9c6790840..b485e5bac4 100644 --- a/monai/transforms/inverse_batch_transform.py +++ b/monai/transforms/inverse_batch_transform.py @@ -99,7 +99,7 @@ def __call__(self, data: Dict[str, Any]) -> Any: re_str = str(re) if "equal size" in re_str: re_str += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`." - raise RuntimeError(re_str) + raise RuntimeError(re_str) from re class Decollated(MapTransform): diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index a9cb847b93..c49f4e6479 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -46,6 +46,7 @@ issequenceiterable, optional_import, ) +from monai.utils.deprecated import deprecated_arg from monai.utils.enums import TransformBackends from monai.utils.module import look_up_option from monai.utils.type_conversion import convert_data_type, convert_to_dst_type @@ -170,6 +171,7 @@ def __call__( data_array (resampled into `self.pixdim`), original affine, current affine. """ + data_array, *_ = convert_data_type(data_array, np.ndarray) # type: ignore _dtype = dtype or self.dtype or data_array.dtype sr = data_array.ndim - 1 if sr <= 0: @@ -274,6 +276,7 @@ def __call__( data_array (reoriented in `self.axcodes`), original axcodes, current axcodes. """ + data_array, *_ = convert_data_type(data_array, np.ndarray) # type: ignore sr = data_array.ndim - 1 if sr <= 0: raise ValueError("data_array must have at least one spatial dimension.") @@ -391,6 +394,7 @@ def __call__( ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore if self.size_mode == "all": input_ndim = img.ndim - 1 # spatial ndim output_ndim = len(ensure_tuple(self.spatial_size)) @@ -441,6 +445,8 @@ class Rotate(Transform, ThreadUnsafe): the output data type is always ``np.float32``. """ + backend = [TransformBackends.TORCH] + def __init__( self, angle: Union[Sequence[float], float], @@ -448,7 +454,7 @@ def __init__( mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, align_corners: bool = False, - dtype: DtypeLike = np.float64, + dtype: Union[DtypeLike, torch.dtype] = np.float64, ) -> None: self.angle = angle self.keep_size = keep_size @@ -460,12 +466,12 @@ def __init__( def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, - dtype: DtypeLike = None, - ) -> np.ndarray: + dtype: Union[DtypeLike, torch.dtype] = None, + ) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D]. @@ -488,7 +494,11 @@ def __call__( """ _dtype = dtype or self.dtype or img.dtype - im_shape = np.asarray(img.shape[1:]) # spatial dimensions + + img_t: torch.Tensor + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype) # type: ignore + + im_shape = np.asarray(img_t.shape[1:]) # spatial dimensions input_ndim = len(im_shape) if input_ndim not in (2, 3): raise ValueError(f"Unsupported img dimension: {input_ndim}, available options are [2, 3].") @@ -506,6 +516,9 @@ def __call__( shift_1 = create_translate(input_ndim, (-(output_shape - 1) / 2).tolist()) transform = shift @ transform @ shift_1 + transform_t: torch.Tensor + transform_t, *_ = convert_to_dst_type(transform, img_t) # type: ignore + xform = AffineTransform( normalized=False, mode=look_up_option(mode or self.mode, GridSampleMode), @@ -513,13 +526,11 @@ def __call__( align_corners=self.align_corners if align_corners is None else align_corners, reverse_indexing=True, ) - output = xform( - torch.as_tensor(np.ascontiguousarray(img).astype(_dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(transform).astype(_dtype)), - spatial_size=output_shape, - ) + output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=output_shape).squeeze(0) self._rotation_matrix = transform - return np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + out: NdarrayOrTensor + out, *_ = convert_to_dst_type(output, dst=img, dtype=output.dtype) + return out def get_rotation_matrix(self) -> Optional[np.ndarray]: """ @@ -738,6 +749,8 @@ class RandRotate(RandomizableTransform): the output data type is always ``np.float32``. """ + backend = Rotate.backend + def __init__( self, range_x: Union[Tuple[float, float], float] = 0.0, @@ -748,7 +761,7 @@ def __init__( mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, align_corners: bool = False, - dtype: DtypeLike = np.float64, + dtype: Union[DtypeLike, torch.dtype] = np.float64, ) -> None: RandomizableTransform.__init__(self, prob) self.range_x = ensure_tuple(range_x) @@ -779,12 +792,12 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, - dtype: DtypeLike = None, - ) -> np.ndarray: + dtype: Union[DtypeLike, torch.dtype] = None, + ) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D). @@ -802,7 +815,9 @@ def __call__( """ self.randomize() if not self._do_transform: - return img + img_t: torch.Tensor + img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore + return img_t rotator = Rotate( angle=self.x if img.ndim == 3 else (self.x, self.y, self.z), keep_size=self.keep_size, @@ -811,7 +826,7 @@ def __call__( align_corners=self.align_corners if align_corners is None else align_corners, dtype=dtype or self.dtype or img.dtype, ) - return np.array(rotator(img)) + return rotator(img) class RandFlip(RandomizableTransform): @@ -1008,14 +1023,15 @@ class AffineGrid(Transform): pixel/voxel relative to the center of the input image. Defaults to no translation. scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D, a tuple of 3 floats for 3D. Defaults to `1.0`. - as_tensor_output: whether to output tensor instead of numpy array, defaults to True. - device: device to store the output grid data. affine: If applied, ignore the params (`rotate_params`, etc.) and use the supplied matrix. Should be square with each side = num of image spatial dimensions + 1. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, rotate_params: Optional[Union[Sequence[float], float]] = None, @@ -1024,23 +1040,20 @@ def __init__( scale_params: Optional[Union[Sequence[float], float]] = None, as_tensor_output: bool = True, device: Optional[torch.device] = None, - affine: Optional[Union[np.ndarray, torch.Tensor]] = None, + affine: Optional[NdarrayOrTensor] = None, ) -> None: self.rotate_params = rotate_params self.shear_params = shear_params self.translate_params = translate_params self.scale_params = scale_params - - self.as_tensor_output = as_tensor_output self.device = device - self.affine = affine def __call__( self, spatial_size: Optional[Sequence[int]] = None, - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]: + grid: Optional[NdarrayOrTensor] = None, + ) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: """ Args: spatial_size: output grid size. @@ -1056,7 +1069,7 @@ def __call__( else: raise ValueError("Incompatible values: grid=None and spatial_size=None.") - affine: Union[torch.Tensor, np.ndarray] + affine: NdarrayOrTensor if self.affine is None: spatial_dims = len(grid.shape) - 1 affine = np.eye(spatial_dims + 1) @@ -1071,17 +1084,13 @@ def __call__( else: affine = self.affine - if isinstance(affine, np.ndarray): - affine = torch.as_tensor(np.ascontiguousarray(affine)) + if self.device not in (None, torch.device("cpu"), "cpu"): + grid, *_ = convert_data_type(grid, torch.Tensor, device=self.device) + grid, *_ = convert_data_type(grid, dtype=float) + affine, *_ = convert_to_dst_type(affine, grid) - grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() - if self.device: - affine = affine.to(self.device) - grid = grid.to(self.device) - grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) - if grid is None or not isinstance(grid, torch.Tensor): - raise ValueError("Unknown grid.") - return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()), affine + grid = (affine @ grid.reshape((grid.shape[0], -1))).reshape([-1] + list(grid.shape[1:])) + return grid, affine class RandAffineGrid(Randomizable, Transform): @@ -1090,6 +1099,9 @@ class RandAffineGrid(Randomizable, Transform): """ + backend = AffineGrid.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, rotate_range: RandRange = None, @@ -1123,8 +1135,6 @@ def __init__( scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select the scale factor to translate for every spatial dims. A value of 1.0 is added to the result. This allows 0 to correspond to no change (i.e., a scaling of 1.0). - as_tensor_output: whether to output tensor instead of numpy array. - defaults to True. device: device to store the output grid data. See also: @@ -1143,9 +1153,8 @@ def __init__( self.translate_params: Optional[List[float]] = None self.scale_params: Optional[List[float]] = None - self.as_tensor_output = as_tensor_output self.device = device - self.affine: Optional[Union[np.ndarray, torch.Tensor]] = None + self.affine: Optional[NdarrayOrTensor] = None def _get_rand_param(self, param_range, add_scalar: float = 0.0): out_param = [] @@ -1167,8 +1176,8 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__( self, spatial_size: Optional[Sequence[int]] = None, - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + grid: Optional[NdarrayOrTensor] = None, + ) -> NdarrayOrTensor: """ Args: spatial_size: output grid size. @@ -1183,13 +1192,13 @@ def __call__( shear_params=self.shear_params, translate_params=self.translate_params, scale_params=self.scale_params, - as_tensor_output=self.as_tensor_output, device=self.device, ) - grid, self.affine = affine_grid(spatial_size, grid) - return grid + _grid: NdarrayOrTensor + _grid, self.affine = affine_grid(spatial_size, grid) + return _grid - def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]: + def get_transformation_matrix(self) -> Optional[NdarrayOrTensor]: """Get the most recently applied transformation matrix""" return self.affine @@ -1245,11 +1254,15 @@ def __call__(self, spatial_size: Sequence[int]): class Resample(Transform): + + backend = [TransformBackends.TORCH] + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, - as_tensor_output: bool = False, + as_tensor_output: bool = True, device: Optional[torch.device] = None, ) -> None: """ @@ -1263,21 +1276,19 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: whether to return a torch tensor. Defaults to False. device: device on which the tensor will be allocated. """ self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) - self.as_tensor_output = as_tensor_output self.device = device def __call__( self, - img: Union[np.ndarray, torch.Tensor], - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, + img: NdarrayOrTensor, + grid: Optional[NdarrayOrTensor] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> NdarrayOrTensor: """ Args: img: shape must be (num_channels, H, W[, D]). @@ -1289,18 +1300,14 @@ def __call__( Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ - - if not isinstance(img, torch.Tensor): - img = torch.as_tensor(np.ascontiguousarray(img)) if grid is None: - raise AssertionError("Error, grid argument must be supplied as an ndarray or tensor ") - grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() - if self.device: - img = img.to(self.device) - grid = grid.to(self.device) + raise ValueError("Unknown grid.") + img_t: torch.Tensor + img_t, *_ = convert_data_type(img, torch.Tensor, device=self.device, dtype=torch.float32) # type: ignore + grid, *_ = convert_to_dst_type(grid, img_t) if USE_COMPILED: - for i, dim in enumerate(img.shape[1:]): + for i, dim in enumerate(img_t.shape[1:]): grid[i] += (dim - 1.0) / 2.0 grid = grid[:-1] / grid[-1:] grid = grid.permute(list(range(grid.ndimension()))[1:] + [0]) @@ -1315,29 +1322,29 @@ def __call__( bound = 1 _interp_mode = look_up_option(self.mode if mode is None else mode, GridSampleMode).value out = grid_pull( - img.unsqueeze(0).float(), - grid.unsqueeze(0).float(), + img_t.unsqueeze(0), + grid.unsqueeze(0), bound=bound, extrapolate=True, interpolation=1 if _interp_mode == "bilinear" else _interp_mode, )[0] else: - for i, dim in enumerate(img.shape[1:]): + for i, dim in enumerate(img_t.shape[1:]): grid[i] = 2.0 * grid[i] / (dim - 1.0) grid = grid[:-1] / grid[-1:] - index_ordering: List[int] = list(range(img.ndimension() - 2, -1, -1)) + index_ordering: List[int] = list(range(img_t.ndimension() - 2, -1, -1)) grid = grid[index_ordering] grid = grid.permute(list(range(grid.ndimension()))[1:] + [0]) out = torch.nn.functional.grid_sample( - img.unsqueeze(0).float(), - grid.unsqueeze(0).float(), + img_t.unsqueeze(0), + grid.unsqueeze(0), mode=self.mode.value if mode is None else GridSampleMode(mode).value, padding_mode=self.padding_mode.value if padding_mode is None else GridSamplePadMode(padding_mode).value, align_corners=True, )[0] - if self.as_tensor_output: - return torch.as_tensor(out) - return np.asarray(out.cpu().numpy()) + out_val: NdarrayOrTensor + out_val, *_ = convert_to_dst_type(out, dst=img, dtype=out.dtype) + return out_val class Affine(Transform): @@ -1347,6 +1354,9 @@ class Affine(Transform): """ + backend = list(set(AffineGrid.backend) & set(Resample.backend)) + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, rotate_params: Optional[Union[Sequence[float], float]] = None, @@ -1356,7 +1366,7 @@ def __init__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, + as_tensor_output: bool = True, device: Optional[torch.device] = None, image_only: bool = False, ) -> None: @@ -1392,8 +1402,6 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. image_only: if True return only the image volume, otherwise return (image, affine). """ @@ -1402,22 +1410,21 @@ def __init__( shear_params=shear_params, translate_params=translate_params, scale_params=scale_params, - as_tensor_output=True, device=device, ) self.image_only = image_only - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) self.spatial_size = spatial_size self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: NdarrayOrTensor, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ): + ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor]]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1447,6 +1454,9 @@ class RandAffine(RandomizableTransform): """ + backend = Affine.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, prob: float = 0.1, @@ -1502,8 +1512,6 @@ def __init__( cache_grid: whether to cache the identity sampling grid. If the spatial size is not dynamically defined by input image, enabling this option could accelerate the transform. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. See also: @@ -1517,10 +1525,9 @@ def __init__( shear_range=shear_range, translate_range=translate_range, scale_range=scale_range, - as_tensor_output=True, device=device, ) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) self.spatial_size = spatial_size self.cache_grid = cache_grid @@ -1577,11 +1584,11 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: NdarrayOrTensor, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> NdarrayOrTensor: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1599,18 +1606,18 @@ def __call__( """ self.randomize() # if not doing transform and spatial size doesn't change, nothing to do - # except convert to float and convert numpy/torch + # except convert to float and device sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) do_resampling = self._do_transform or (sp_size != ensure_tuple(img.shape[1:])) if not do_resampling: - img = img.float() if isinstance(img, torch.Tensor) else img.astype("float32") - return torch.Tensor(img) if self.resampler.as_tensor_output else np.array(img) + img, *_ = convert_data_type(img, dtype=torch.float32, device=self.resampler.device) grid = self.get_identity_grid(sp_size) if self._do_transform: grid = self.rand_affine_grid(grid=grid) - return self.resampler( + out: NdarrayOrTensor = self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) + return out class Rand2DElastic(RandomizableTransform): @@ -1620,6 +1627,9 @@ class Rand2DElastic(RandomizableTransform): """ + backend = Resample.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, spacing: Union[Tuple[float, float], float], @@ -1674,8 +1684,6 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. See also: @@ -1691,10 +1699,9 @@ def __init__( shear_range=shear_range, translate_range=translate_range, scale_range=scale_range, - as_tensor_output=True, device=device, ) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) self.spatial_size = spatial_size self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) @@ -1715,11 +1722,11 @@ def randomize(self, spatial_size: Sequence[int]) -> None: def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: NdarrayOrTensor, spatial_size: Optional[Union[Tuple[int, int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> NdarrayOrTensor: """ Args: img: shape must be (num_channels, H, W), @@ -1748,7 +1755,10 @@ def __call__( grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) else: grid = create_grid(spatial_size=sp_size) - return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) + out: NdarrayOrTensor = self.resampler( + img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode + ) + return out class Rand3DElastic(RandomizableTransform): @@ -1758,6 +1768,9 @@ class Rand3DElastic(RandomizableTransform): """ + backend = Resample.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, sigma_range: Tuple[float, float], @@ -1815,8 +1828,6 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. See also: @@ -1824,8 +1835,14 @@ def __init__( - :py:class:`Affine` for the affine transformation parameters configurations. """ RandomizableTransform.__init__(self, prob) - self.rand_affine_grid = RandAffineGrid(rotate_range, shear_range, translate_range, scale_range, True, device) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.rand_affine_grid = RandAffineGrid( + rotate_range=rotate_range, + shear_range=shear_range, + translate_range=translate_range, + scale_range=scale_range, + device=device, + ) + self.resampler = Resample(device=device) self.sigma_range = sigma_range self.magnitude_range = magnitude_range @@ -1855,11 +1872,11 @@ def randomize(self, grid_size: Sequence[int]) -> None: def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: NdarrayOrTensor, spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> NdarrayOrTensor: """ Args: img: shape must be (num_channels, H, W, D), @@ -1884,7 +1901,10 @@ def __call__( offset = torch.as_tensor(self.rand_offset, device=self.device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.magnitude grid = self.rand_affine_grid(grid=grid) - return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) + out: NdarrayOrTensor = self.resampler( + img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode + ) + return out class AddCoordinateChannels(Transform): @@ -1914,6 +1934,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]): Args: img: data to be transformed, assuming `img` is channel first. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore if max(self.spatial_channels) > img.ndim - 1: raise ValueError( f"input has {img.ndim-1} spatial dimensions, cannot add AddCoordinateChannels channel for " diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 96fe21db12..d794e51e80 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -55,8 +55,10 @@ ensure_tuple_rep, fall_back_tuple, ) +from monai.utils.deprecated import deprecated_arg from monai.utils.enums import InverseKeys from monai.utils.module import optional_import +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type nib, _ = optional_import("nibabel") @@ -573,6 +575,9 @@ class Affined(MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Affine`. """ + backend = Affine.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -583,7 +588,7 @@ def __init__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, + as_tensor_output: bool = True, device: Optional[torch.device] = None, allow_missing_keys: bool = False, ) -> None: @@ -620,8 +625,6 @@ def __init__( Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. @@ -636,15 +639,12 @@ def __init__( translate_params=translate_params, scale_params=scale_params, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): orig_size = d[key].shape[1:] @@ -661,7 +661,7 @@ def __call__( ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -677,10 +677,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar grid, _ = affine_grid(orig_size) # type: ignore # Apply inverse transform - out = self.affine.resampler(d[key], grid, mode, padding_mode) - - # Convert to numpy - d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy() + d[key] = self.affine.resampler(d[key], grid, mode, padding_mode) # Remove the applied transform self.pop_transform(d, key) @@ -693,6 +690,9 @@ class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.RandAffine`. """ + backend = Affine.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -753,8 +753,6 @@ def __init__( cache_grid: whether to cache the identity sampling grid. If the spatial size is not dynamically defined by input image, enabling this option could accelerate the transform. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. @@ -772,7 +770,6 @@ def __init__( scale_range=scale_range, spatial_size=spatial_size, cache_grid=cache_grid, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -789,18 +786,19 @@ def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) self.rand_affine.randomize() - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize() + device = self.rand_affine.resampler.device + sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) # change image size or do random transform do_resampling = self._do_transform or (sp_size != ensure_tuple(data[self.keys[0]].shape[1:])) - # to be consistent with the self._do_transform case (dtype and device) - affine = torch.as_tensor(np.eye(len(sp_size) + 1), device=self.rand_affine.rand_affine_grid.device) + affine: NdarrayOrTensor = np.eye(len(sp_size) + 1, dtype=np.float64) + if device not in (None, torch.device("cpu"), "cpu"): + affine, *_ = convert_data_type(affine, torch.Tensor, device=device) grid = None if do_resampling: # need to prepare grid grid = self.rand_affine.get_identity_grid(sp_size) @@ -821,23 +819,17 @@ def __call__( # do the transform if do_resampling: d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) - # if not doing transform and and spatial size is unchanged, only need to do numpy/torch conversion - else: - if self.rand_affine.resampler.as_tensor_output and not isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]) - elif not self.rand_affine.resampler.as_tensor_output and isinstance(d[key], torch.Tensor): - d[key] = d[key].detach().cpu().numpy() # type: ignore[union-attr] return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # if transform was not performed and spatial size is None, nothing to do. if not transform[InverseKeys.DO_TRANSFORM] and self.rand_affine.spatial_size is None: - out: Union[np.ndarray, torch.Tensor] = d[key] + out: NdarrayOrTensor = d[key] else: orig_size = transform[InverseKeys.ORIG_SIZE] # Create inverse transform @@ -850,10 +842,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar grid, _ = affine_grid(orig_size) # type: ignore # Apply inverse transform - out = self.rand_affine.resampler(d[key], grid, mode, padding_mode) - - # Convert to numpy - d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy() + d[key] = self.rand_affine.resampler(d[key], grid, mode, padding_mode) # Remove the applied transform self.pop_transform(d, key) @@ -866,6 +855,9 @@ class Rand2DElasticd(RandomizableTransform, MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Rand2DElastic`. """ + backend = Rand2DElastic.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -926,8 +918,6 @@ def __init__( Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. @@ -946,7 +936,6 @@ def __init__( translate_range=translate_range, scale_range=scale_range, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -963,9 +952,7 @@ def randomize(self, spatial_size: Sequence[int]) -> None: super().randomize(None) self.rand_2d_elastic.randomize(spatial_size) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, data[self.keys[0]].shape[1:]) @@ -995,6 +982,9 @@ class Rand3DElasticd(RandomizableTransform, MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Rand3DElastic`. """ + backend = Rand3DElastic.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -1057,8 +1047,6 @@ def __init__( Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. @@ -1077,7 +1065,6 @@ def __init__( translate_range=translate_range, scale_range=scale_range, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -1094,9 +1081,7 @@ def randomize(self, grid_size: Sequence[int]) -> None: super().randomize(None) self.rand_3d_elastic.randomize(grid_size) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, data[self.keys[0]].shape[1:]) @@ -1287,6 +1272,8 @@ class Rotated(MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = Rotate.backend + def __init__( self, keys: KeysCollection, @@ -1295,7 +1282,7 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], Union[DtypeLike, torch.dtype]] = np.float64, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) @@ -1306,7 +1293,7 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype @@ -1333,7 +1320,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) @@ -1351,12 +1338,17 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar align_corners=False if align_corners == "none" else align_corners, reverse_indexing=True, ) + img_t: torch.Tensor + img_t, *_ = convert_data_type(d[key], torch.Tensor, dtype=dtype) # type: ignore + transform_t: torch.Tensor + transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) # type: ignore + output = xform( - torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), + img_t.unsqueeze(0), + transform_t, spatial_size=transform[InverseKeys.ORIG_SIZE], ) - d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + d[key] = output.squeeze(0).detach().float() # Remove the applied transform self.pop_transform(d, key) @@ -1398,6 +1390,8 @@ class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = Rotate.backend + def __init__( self, keys: KeysCollection, @@ -1409,7 +1403,7 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], Union[DtypeLike, torch.dtype]] = np.float64, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) @@ -1440,14 +1434,11 @@ def randomize(self, data: Optional[Any] = None) -> None: self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: self.randomize() d = dict(data) angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z) - rotator = Rotate( - angle=angle, - keep_size=self.keep_size, - ) + rotator = Rotate(angle=angle, keep_size=self.keep_size) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): @@ -1476,7 +1467,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) @@ -1496,12 +1487,17 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar align_corners=False if align_corners == "none" else align_corners, reverse_indexing=True, ) + img_t: torch.Tensor + img_t, *_ = convert_data_type(d[key], torch.Tensor, dtype=dtype) # type: ignore + transform_t: torch.Tensor + transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) # type: ignore + output: torch.Tensor output = xform( - torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), + img_t.unsqueeze(0), + transform_t, spatial_size=transform[InverseKeys.ORIG_SIZE], ) - d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + d[key] = output.squeeze(0).detach().float() # Remove the applied transform self.pop_transform(d, key) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index e9bcce93b0..a3c51fe3f2 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -17,6 +17,7 @@ import copy import logging +import re from copy import deepcopy from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union @@ -442,15 +443,23 @@ class ToTensord(MapTransform, InvertibleTransform): backend = ToTensor.backend - def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + dtype: target data content type to convert, for example: torch.float, etc. + device: specify the target device to put the Tensor data. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.converter = ToTensor() + self.converter = ToTensor(dtype=dtype, device=device) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -486,16 +495,25 @@ class EnsureTyped(MapTransform, InvertibleTransform): backend = EnsureType.backend - def __init__(self, keys: KeysCollection, data_type: str = "tensor", allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + data_type: str = "tensor", + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + device: Optional[torch.device] = None, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` data_type: target data type to convert, should be "tensor" or "numpy". + dtype: target data content type to convert, for example: np.float32, torch.float, etc. + device: for Tensor data type, specify the target device. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.converter = EnsureType(data_type=data_type) + self.converter = EnsureType(data_type=data_type, dtype=dtype, device=device) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -522,15 +540,21 @@ class ToNumpyd(MapTransform): backend = ToNumpy.backend - def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + dtype: Optional[DtypeLike] = None, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + dtype: target data type when converting to numpy array. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.converter = ToNumpy() + self.converter = ToNumpy(dtype=dtype) def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: d = dict(data) @@ -542,19 +566,19 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: class ToCupyd(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.ToCupy`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + dtype: data type specifier. It is inferred from the input by default. + allow_missing_keys: don't raise exception if key is missing. """ backend = ToCupy.backend - def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: - """ - Args: - keys: keys of the corresponding items to be transformed. - See also: :py:class:`monai.transforms.compose.MapTransform` - allow_missing_keys: don't raise exception if key is missing. - """ + def __init__(self, keys: KeysCollection, dtype=None, allow_missing_keys: bool = False) -> None: super().__init__(keys, allow_missing_keys) - self.converter = ToCupy() + self.converter = ToCupy(dtype=dtype) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -630,8 +654,38 @@ class DeleteItemsd(MapTransform): It will remove the key-values and copy the others to construct a new dictionary. """ + def __init__( + self, + keys: KeysCollection, + sep: str = ".", + use_re: Union[Sequence[bool], bool] = False, + ) -> None: + """ + Args: + keys: keys of the corresponding items to delete, can be "A{sep}B{sep}C" + to delete key `C` in nested dictionary, `C` can be regular expression. + See also: :py:class:`monai.transforms.compose.MapTransform` + sep: the separator tag to define nested dictionary keys, default to ".". + use_re: whether the specified key is a regular expression, it also can be + a list of bool values, map the to keys. + """ + super().__init__(keys) + self.sep = sep + self.use_re = ensure_tuple_rep(use_re, len(self.keys)) + def __call__(self, data): - return {key: val for key, val in data.items() if key not in self.key_iterator(data)} + def _delete_item(keys, d, use_re: bool = False): + key = keys[0] + if len(keys) > 1: + d[key] = _delete_item(keys[1:], d[key], use_re) + return d + return {k: v for k, v in d.items() if (use_re and not re.search(key, k)) or (not use_re and k != key)} + + d = dict(data) + for key, use_re in zip(self.keys, self.use_re): + d = _delete_item(key.split(self.sep), d, use_re) + + return d class SelectItemsd(MapTransform): diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 2eebe3eda3..0fb8e34ef0 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union + import numpy as np import torch @@ -17,6 +19,9 @@ __all__ = [ "moveaxis", "in1d", + "clip", + "percentile", + "where", ] @@ -50,3 +55,67 @@ def in1d(x, y): if isinstance(x, np.ndarray): return np.in1d(x, y) return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1) + + +def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor: + """`np.clip` with equivalent implementation for torch.""" + result: NdarrayOrTensor + if isinstance(a, np.ndarray): + result = np.clip(a, a_min, a_max) + else: + result = torch.clip(a, a_min, a_max) + return result + + +def percentile(x: NdarrayOrTensor, q) -> Union[NdarrayOrTensor, float, int]: + """`np.percentile` with equivalent implementation for torch. + + Pytorch uses `quantile`, but this functionality is only available from v1.7. + For earlier methods, we calculate it ourselves. This doesn't do interpolation, + so is the equivalent of ``numpy.percentile(..., interpolation="nearest")``. + + Args: + x: input data + q: percentile to compute (should in range 0 <= q <= 100) + + Returns: + Resulting value (scalar) + """ + if np.isscalar(q): + if not 0 <= q <= 100: + raise ValueError + else: + if any(q < 0) or any(q > 100): + raise ValueError + result: Union[NdarrayOrTensor, float, int] + if isinstance(x, np.ndarray): + result = np.percentile(x, q) + else: + q = torch.tensor(q, device=x.device) + if hasattr(torch, "quantile"): + result = torch.quantile(x, q / 100.0) + else: + # Note that ``kthvalue()`` works one-based, i.e., the first sorted value + # corresponds to k=1, not k=0. Thus, we need the `1 +`. + k = 1 + (0.01 * q * (x.numel() - 1)).round().int() + if k.numel() > 1: + r = [x.view(-1).kthvalue(int(_k)).values.item() for _k in k] + result = torch.tensor(r, device=x.device) + else: + result = x.view(-1).kthvalue(int(k)).values.item() + + return result + + +def where(condition: NdarrayOrTensor, x, y) -> NdarrayOrTensor: + """ + Note that `torch.where` may convert y.dtype to x.dtype. + """ + result: NdarrayOrTensor + if isinstance(condition, np.ndarray): + result = np.where(condition, x, y) + else: + x = torch.as_tensor(x, device=condition.device) + y = torch.as_tensor(y, device=condition.device, dtype=x.dtype) + result = torch.where(condition, x, y) + return result diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index aa8f02f815..dc3922933d 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -77,6 +77,7 @@ from .state_cacher import StateCacher from .type_conversion import ( convert_data_type, + convert_to_cupy, convert_to_dst_type, convert_to_numpy, convert_to_tensor, diff --git a/monai/utils/aliases.py b/monai/utils/aliases.py index 2b7b29eeb5..a08dab4f95 100644 --- a/monai/utils/aliases.py +++ b/monai/utils/aliases.py @@ -70,8 +70,8 @@ def resolve_name(name): try: mod = importlib.import_module(modname) obj = getattr(mod, declname, None) - except ModuleNotFoundError: - raise ValueError(f"Module {modname!r} not found.") + except ModuleNotFoundError as not_found_err: + raise ValueError(f"Module {modname!r} not found.") from not_found_err if obj is None: raise ValueError(f"Module {modname!r} does not have member {declname!r}.") diff --git a/requirements-dev.txt b/requirements-dev.txt index ed8739ded8..785454ad5d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -36,4 +36,3 @@ openslide-python==1.1.2 pandas requests einops -transformers diff --git a/setup.cfg b/setup.cfg index f7ed90a14a..6efe768a6f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,7 +44,6 @@ all = openslide-python==1.1.2 pandas einops - transformers nibabel = nibabel skimage = @@ -75,8 +74,6 @@ pandas = pandas einops = einops -transformers = - transformers [flake8] select = B,C,E,F,N,P,T4,W,B9 max_line_length = 120 diff --git a/tests/min_tests.py b/tests/min_tests.py index bac6521889..5b376d7b57 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -140,7 +140,6 @@ def run_testsuit(): "test_zoom", "test_zoom_affine", "test_zoomd", - "test_vltransformer", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_affine.py b/tests/test_affine.py index dd82d72e23..bd89f1a436 100644 --- a/tests/test_affine.py +++ b/tests/test_affine.py @@ -16,78 +16,139 @@ from parameterized import parameterized from monai.transforms import Affine +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(9).reshape((1, 3, 3)), "spatial_size": (-1, 0)}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None, image_only=True), - {"img": np.arange(9).reshape((1, 3, 3)), "spatial_size": (-1, 0)}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape(1, 2, 2), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2)), "spatial_size": (4, 4)}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(rotate_params=[np.pi / 2], padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2)), "spatial_size": (4, 4)}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(27).reshape((1, 3, 3, 3)), "spatial_size": (-1, 0, 0)}, - np.arange(27).reshape(1, 3, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2)), "spatial_size": (4, 4, 4)}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(9).reshape((1, 3, 3))), "spatial_size": (-1, 0)}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], - [ - dict(rotate_params=[np.pi / 2], padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2)), "spatial_size": (4, 4, 4)}, - np.array( + ) + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 6.0, 4.0, 0.0], [0.0, 7.0, 5.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(padding_mode="zeros", device=device, image_only=True), + {"img": p(np.arange(9).reshape((1, 3, 3))), "spatial_size": (-1, 0)}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], -] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape(1, 2, 2)), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(rotate_params=[np.pi / 2], padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(27).reshape((1, 3, 3, 3))), "spatial_size": (-1, 0, 0)}, + p(np.arange(27).reshape(1, 3, 3, 3)), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "spatial_size": (4, 4, 4)}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict(rotate_params=[np.pi / 2], padding_mode="zeros", device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "spatial_size": (4, 4, 4)}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 3.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 6.0, 4.0, 0.0], + [0.0, 7.0, 5.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) class TestAffine(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine(self, input_param, input_data, expected_val): g = Affine(**input_param) result = g(**input_data) if isinstance(result, tuple): result = result[0] - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_affine_grid.py b/tests/test_affine_grid.py index 24772b9a21..972cf20a1f 100644 --- a/tests/test_affine_grid.py +++ b/tests/test_affine_grid.py @@ -16,88 +16,106 @@ from parameterized import parameterized from monai.transforms import AffineGrid +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - {"as_tensor_output": False, "device": torch.device("cpu:0")}, - {"spatial_size": (2, 2)}, - np.array([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), - ], - [ - {"as_tensor_output": True, "device": None}, - {"spatial_size": (2, 2)}, - torch.tensor([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), - ], - [{"as_tensor_output": False, "device": None}, {"grid": np.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [{"as_tensor_output": True, "device": torch.device("cpu:0")}, {"grid": np.ones((3, 3, 3))}, torch.ones((3, 3, 3))], - [{"as_tensor_output": False, "device": None}, {"grid": torch.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [ - {"as_tensor_output": True, "device": torch.device("cpu:0")}, - {"grid": torch.ones((3, 3, 3))}, - torch.ones((3, 3, 3)), - ], - [ - { - "rotate_params": (1.0, 1.0), - "scale_params": (-20, 10), - "as_tensor_output": True, - "device": torch.device("cpu:0"), - }, - {"grid": torch.ones((3, 3, 3))}, - torch.tensor( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[-19.2208, -19.2208, -19.2208], [-19.2208, -19.2208, -19.2208], [-19.2208, -19.2208, -19.2208]], - [[-11.4264, -11.4264, -11.4264], [-11.4264, -11.4264, -11.4264], [-11.4264, -11.4264, -11.4264]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + {"device": device}, + {"spatial_size": (2, 2)}, + np.array([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), ] - ), - ], - [ - { - "rotate_params": (1.0, 1.0, 1.0), - "scale_params": (-20, 10), - "as_tensor_output": True, - "device": torch.device("cpu:0"), - }, - {"grid": torch.ones((4, 3, 3, 3))}, - torch.tensor( + ) + + TESTS.append([{"device": device}, {"grid": p(np.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append([{"device": device}, {"grid": p(torch.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append( + [ + { + "rotate_params": (1.0, 1.0), + "scale_params": (-20, 10), + "device": device, + }, + {"grid": p(torch.ones((3, 3, 3)))}, + p( + torch.tensor( + [ + [ + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + ], + [ + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + ], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ) + ), + ] + ) + TESTS.append( [ - [ - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - ], - [ - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - ], - [ - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - ], - [ - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - ], + { + "rotate_params": (1.0, 1.0, 1.0), + "scale_params": (-20, 10), + "device": device, + }, + {"grid": p(torch.ones((4, 3, 3, 3)))}, + p( + torch.tensor( + [ + [ + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + ], + [ + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + ], + [ + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + ], + [ + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + ], + ] + ) + ), ] - ), - ], -] + ) class TestAffineGrid(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine_grid(self, input_param, input_data, expected_val): g = AffineGrid(**input_param) result, _ = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + if "device" in input_data: + self.assertEqual(result.device, input_data[device]) + assert_allclose(result, expected_val, type_test=False, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_affined.py b/tests/test_affined.py index 850f12905d..142cedc8d9 100644 --- a/tests/test_affined.py +++ b/tests/test_affined.py @@ -16,85 +16,142 @@ from parameterized import parameterized from monai.transforms import Affined +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - dict(keys="img", padding_mode="zeros", as_tensor_output=False, spatial_size=(-1, 0), device=None), - {"img": np.arange(9).reshape((1, 3, 3))}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(keys="img", padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape(1, 2, 2), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(4, 4), as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict( - keys="img", - rotate_params=[np.pi / 2], - padding_mode="zeros", - spatial_size=(4, 4), - as_tensor_output=False, - device=None, - ), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0, 0), as_tensor_output=False, device=None), - {"img": np.arange(27).reshape((1, 3, 3, 3))}, - np.arange(27).reshape(1, 3, 3, 3), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(4, 4, 4), as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0), device=device), + {"img": p(np.arange(9).reshape((1, 3, 3)))}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], - [ - dict( - keys="img", - rotate_params=[np.pi / 2], - padding_mode="zeros", - spatial_size=(4, 4, 4), - as_tensor_output=False, - device=None, - ), - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.array( + ) + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 6.0, 4.0, 0.0], [0.0, 7.0, 5.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(keys="img", padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape(1, 2, 2)), ] - ), - ], -] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(4, 4), device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict( + keys="img", + rotate_params=[np.pi / 2], + padding_mode="zeros", + spatial_size=(4, 4), + device=device, + ), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0, 0), device=device), + {"img": p(np.arange(27).reshape((1, 3, 3, 3)))}, + p(np.arange(27).reshape(1, 3, 3, 3)), + ] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(4, 4, 4), device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict( + keys="img", + rotate_params=[np.pi / 2], + padding_mode="zeros", + spatial_size=(4, 4, 4), + device=device, + ), + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 3.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 6.0, 4.0, 0.0], + [0.0, 7.0, 5.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) class TestAffined(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine(self, input_param, input_data, expected_val): g = Affined(**input_param) result = g(input_data)["img"] - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index 0d1b1c7d3a..918e576011 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -34,7 +34,7 @@ def test_value(self, in_type, input_param, expected_shape): if isinstance(test_data, torch.Tensor): test_data = test_data.cpu().numpy() expected = np.moveaxis(test_data, input_param["channel_dim"], 0) - assert_allclose(expected, result) + assert_allclose(result, expected, type_test=False) if __name__ == "__main__": diff --git a/tests/test_delete_itemsd.py b/tests/test_delete_itemsd.py index 7426e39ff0..b7cd104c46 100644 --- a/tests/test_delete_itemsd.py +++ b/tests/test_delete_itemsd.py @@ -19,19 +19,36 @@ TEST_CASE_1 = [{"keys": [str(i) for i in range(30)]}, 20] +TEST_CASE_2 = [{"keys": ["image/" + str(i) for i in range(30)], "sep": "/"}, 20] + +TEST_CASE_3 = [{"keys": "meta_dict%0008\\|[0-9]", "sep": "%", "use_re": True}] + class TestDeleteItemsd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_memory(self, input_param, expected_key_size): - input_data = {} + input_data = {"image": {}} if "sep" in input_param else {} for i in range(50): - input_data[str(i)] = [time.time()] * 100000 + if "sep" in input_param: + input_data["image"][str(i)] = [time.time()] * 100000 + else: + input_data[str(i)] = [time.time()] * 100000 result = DeleteItemsd(**input_param)(input_data) - self.assertEqual(len(result.keys()), expected_key_size) + if "sep" in input_param: + self.assertEqual(len(result["image"].keys()), expected_key_size) + else: + self.assertEqual(len(result.keys()), expected_key_size) self.assertGreaterEqual( sys.getsizeof(input_data) * float(expected_key_size) / len(input_data), sys.getsizeof(result) ) + @parameterized.expand([TEST_CASE_3]) + def test_re(self, input_param): + input_data = {"image": [1, 2, 3], "meta_dict": {"0008|0005": 1, "0008|1050": 2, "0008test": 3}} + result = DeleteItemsd(**input_param)(input_data) + self.assertEqual(result["meta_dict"]["0008test"], 3) + self.assertTrue(len(result["meta_dict"]), 1) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 81ed239461..18fe146a40 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -26,14 +26,14 @@ expected_shape: Sequence[Any] TEST_CASE_DYNUNET_2D = [] +out_channels = 2 +in_size = 64 +spatial_dims = 2 for kernel_size in [(3, 3, 3, 1), ((3, 1), 1, (3, 3), (1, 1))]: for strides in [(1, 1, 1, 1), (2, 2, 2, 1)]: + expected_shape = (1, out_channels, *[in_size // strides[0]] * spatial_dims) for in_channels in [2, 3]: for res_block in [True, False]: - out_channels = 2 - in_size = 64 - spatial_dims = 2 - expected_shape = (1, out_channels, *[in_size // strides[0]] * spatial_dims) test_case = [ { "spatial_dims": spatial_dims, @@ -45,6 +45,7 @@ "norm_name": "batch", "deep_supervision": False, "res_block": res_block, + "dropout": None, }, (1, in_channels, in_size, in_size), expected_shape, @@ -52,11 +53,11 @@ TEST_CASE_DYNUNET_2D.append(test_case) TEST_CASE_DYNUNET_3D = [] # in 3d cases, also test anisotropic kernel/strides +in_channels = 1 +in_size = 64 for out_channels in [2, 3]: + expected_shape = (1, out_channels, 64, 32, 64) for res_block in [True, False]: - in_channels = 1 - in_size = 64 - expected_shape = (1, out_channels, 64, 32, 64) test_case = [ { "spatial_dims": 3, @@ -68,6 +69,7 @@ "norm_name": ("INSTANCE", {"affine": True}), "deep_supervision": False, "res_block": res_block, + "dropout": ("alphadropout", {"p": 0.25}), }, (1, in_channels, in_size, in_size, in_size), expected_shape, diff --git a/tests/test_ensure_type.py b/tests/test_ensure_type.py index 8feb96ed37..64094b2360 100644 --- a/tests/test_ensure_type.py +++ b/tests/test_ensure_type.py @@ -25,9 +25,11 @@ def test_array_input(self): test_datas.append(test_datas[-1].cuda()) for test_data in test_datas: for dtype in ("tensor", "NUMPY"): - result = EnsureType(data_type=dtype)(test_data) + result = EnsureType(dtype, dtype=np.float32 if dtype == "NUMPY" else None, device="cpu")(test_data) + if dtype == "NUMPY": + self.assertTrue(result.dtype == np.float32) self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) self.assertTupleEqual(result.shape, (2, 2)) def test_single_input(self): @@ -36,12 +38,12 @@ def test_single_input(self): test_datas.append(test_datas[-1].cuda()) for test_data in test_datas: for dtype in ("tensor", "numpy"): - result = EnsureType(data_type=dtype)(test_data) + result = EnsureType(data_type=dtype, device="cpu")(test_data) self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) if isinstance(test_data, bool): self.assertFalse(result) else: - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) self.assertEqual(result.ndim, 0) def test_string(self): diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py index 96f482afc2..a78df6cb3f 100644 --- a/tests/test_ensure_typed.py +++ b/tests/test_ensure_typed.py @@ -25,9 +25,16 @@ def test_array_input(self): test_datas.append(test_datas[-1].cuda()) for test_data in test_datas: for dtype in ("tensor", "NUMPY"): - result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"] + result = EnsureTyped( + keys="data", + data_type=dtype, + dtype=np.float32 if dtype == "NUMPY" else None, + device="cpu", + )({"data": test_data})["data"] + if dtype == "NUMPY": + self.assertTrue(result.dtype == np.float32) self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) self.assertTupleEqual(result.shape, (2, 2)) def test_single_input(self): @@ -41,7 +48,7 @@ def test_single_input(self): if isinstance(test_data, bool): self.assertFalse(result) else: - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) self.assertEqual(result.ndim, 0) def test_string(self): @@ -75,7 +82,7 @@ def test_dict(self): "extra": None, } for dtype in ("tensor", "numpy"): - result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"] + result = EnsureTyped(keys="data", data_type=dtype, device="cpu")({"data": test_data})["data"] self.assertTrue(isinstance(result, dict)) self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray)) torch.testing.assert_allclose(result["img"], torch.as_tensor([1.0, 2.0])) diff --git a/tests/test_flip.py b/tests/test_flip.py index 404a3def7d..8547f8aeb4 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -34,12 +34,10 @@ def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS: im = p(self.imt[0]) flip = Flip(spatial_axis=spatial_axis) - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, spatial_axis)) + expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) result = flip(im) - assert_allclose(expected, result) + assert_allclose(result, p(expected)) if __name__ == "__main__": diff --git a/tests/test_flipd.py b/tests/test_flipd.py index 1676723800..2fa783f8ad 100644 --- a/tests/test_flipd.py +++ b/tests/test_flipd.py @@ -33,12 +33,10 @@ def test_invalid_cases(self, _, spatial_axis, raises): def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS: flip = Flipd(keys="img", spatial_axis=spatial_axis) - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, spatial_axis)) + expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) result = flip({"img": p(self.imt[0])})["img"] - assert_allclose(expected, result) + assert_allclose(result, p(expected)) if __name__ == "__main__": diff --git a/tests/test_inverse_collation.py b/tests/test_inverse_collation.py index c5dd9f1210..bc0fc3ff1b 100644 --- a/tests/test_inverse_collation.py +++ b/tests/test_inverse_collation.py @@ -61,7 +61,6 @@ prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), - as_tensor_output=False, ), ] ] @@ -85,7 +84,6 @@ prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), - as_tensor_output=False, ), ] ] @@ -117,10 +115,7 @@ def tearDown(self): @parameterized.expand(TESTS_2D + TESTS_3D) def test_collation(self, _, transform, collate_fn, ndim): - if ndim == 3: - data = self.data_3d - else: - data = self.data_2d + data = self.data_3d if ndim == 3 else self.data_2d if collate_fn: modified_transform = transform else: diff --git a/tests/test_label_to_mask.py b/tests/test_label_to_mask.py index 9caa7252f3..6c8f935fbc 100644 --- a/tests/test_label_to_mask.py +++ b/tests/test_label_to_mask.py @@ -64,7 +64,7 @@ def test_value(self, argments, image, expected_data): self.assertEqual(type(result), type(image)) if isinstance(result, torch.Tensor): self.assertEqual(result.device, image.device) - assert_allclose(result, expected_data) + assert_allclose(result, expected_data, type_test=False) if __name__ == "__main__": diff --git a/tests/test_label_to_maskd.py b/tests/test_label_to_maskd.py index b8f0d3c171..b2073e8ac3 100644 --- a/tests/test_label_to_maskd.py +++ b/tests/test_label_to_maskd.py @@ -65,7 +65,7 @@ def test_value(self, argments, input_data, expected_data): self.assertEqual(type(r), type(i)) if isinstance(r, torch.Tensor): self.assertEqual(r.device, i.device) - assert_allclose(r, expected_data) + assert_allclose(r, expected_data, type_test=False) if __name__ == "__main__": diff --git a/tests/test_normalize_intensity.py b/tests/test_normalize_intensity.py index 2755eb4c25..41c6b053ec 100644 --- a/tests/test_normalize_intensity.py +++ b/tests/test_normalize_intensity.py @@ -31,51 +31,51 @@ "divisor": u(np.array([0.5, 0.5, 0.5, 0.5])), "nonzero": True, }, - np.array([0.0, 3.0, 0.0, 4.0]), - np.array([0.0, -1.0, 0.0, 1.0]), + p(np.array([0.0, 3.0, 0.0, 4.0])), + p(np.array([0.0, -1.0, 0.0, 1.0])), ] ) - TESTS.append([p, {"nonzero": True}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])]) - TESTS.append([p, {"nonzero": False}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])]) - TESTS.append([p, {"nonzero": False}, np.array([1, 1, 1, 1]), np.array([0.0, 0.0, 0.0, 0.0])]) + TESTS.append([p, {"nonzero": True}, p(np.array([0.0, 0.0, 0.0, 0.0])), p(np.array([0.0, 0.0, 0.0, 0.0]))]) + TESTS.append([p, {"nonzero": False}, p(np.array([0.0, 0.0, 0.0, 0.0])), p(np.array([0.0, 0.0, 0.0, 0.0]))]) + TESTS.append([p, {"nonzero": False}, p(np.array([1, 1, 1, 1])), p(np.array([0.0, 0.0, 0.0, 0.0]))]) TESTS.append( [ p, {"nonzero": False, "channel_wise": True, "subtrahend": [1, 2, 3]}, - np.ones((3, 2, 2)), - np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]]), + p(np.ones((3, 2, 2))), + p(np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]])), ] ) TESTS.append( [ p, {"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2]}, - np.ones((3, 2, 2)), - np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]), + p(np.ones((3, 2, 2))), + p(np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]])), ] ) TESTS.append( [ p, {"nonzero": True, "channel_wise": False, "subtrahend": 2, "divisor": 0}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * -1.0, + p(np.ones((3, 2, 2))), + p(np.ones((3, 2, 2)) * -1.0), ] ) TESTS.append( [ p, {"nonzero": True, "channel_wise": False, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": 0}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * 0.5, + p(np.ones((3, 2, 2))), + p(np.ones((3, 2, 2)) * 0.5), ] ) TESTS.append( [ p, {"nonzero": True, "channel_wise": True, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": [0, 1, 0]}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * 0.5, + p(np.ones((3, 2, 2))), + p(np.ones((3, 2, 2)) * 0.5), ] ) @@ -91,17 +91,14 @@ def test_default(self, im_type): self.assertEqual(im.device, normalized.device) self.assertTrue(normalized.dtype in (np.float32, torch.float32)) expected = (self.imt - np.mean(self.imt)) / np.std(self.imt) - assert_allclose(expected, normalized, rtol=1e-3) + assert_allclose(normalized, expected, type_test=False, rtol=1e-3) @parameterized.expand(TESTS) def test_nonzero(self, in_type, input_param, input_data, expected_data): normalizer = NormalizeIntensity(**input_param) im = in_type(input_data) normalized = normalizer(im) - self.assertEqual(type(im), type(normalized)) - if isinstance(normalized, torch.Tensor): - self.assertEqual(im.device, normalized.device) - assert_allclose(expected_data, normalized) + assert_allclose(normalized, in_type(expected_data)) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_channel_wise(self, im_type): @@ -109,10 +106,7 @@ def test_channel_wise(self, im_type): input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])) expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]]) normalized = normalizer(input_data) - self.assertEqual(type(input_data), type(normalized)) - if isinstance(normalized, torch.Tensor): - self.assertEqual(input_data.device, normalized.device) - assert_allclose(expected, normalized) + assert_allclose(normalized, im_type(expected)) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_value_errors(self, im_type): diff --git a/tests/test_normalize_intensityd.py b/tests/test_normalize_intensityd.py index e2cec5407a..60b1d05456 100644 --- a/tests/test_normalize_intensityd.py +++ b/tests/test_normalize_intensityd.py @@ -25,7 +25,7 @@ [ {"keys": ["img"], "nonzero": True}, {"img": p(np.array([0.0, 3.0, 0.0, 4.0]))}, - np.array([0.0, -1.0, 0.0, 1.0]), + p(np.array([0.0, -1.0, 0.0, 1.0])), ] ) TESTS.append( @@ -37,14 +37,14 @@ "nonzero": True, }, {"img": p(np.array([0.0, 3.0, 0.0, 4.0]))}, - np.array([0.0, -1.0, 0.0, 1.0]), + p(np.array([0.0, -1.0, 0.0, 1.0])), ] ) TESTS.append( [ {"keys": ["img"], "nonzero": True}, {"img": p(np.array([0.0, 0.0, 0.0, 0.0]))}, - np.array([0.0, 0.0, 0.0, 0.0]), + p(np.array([0.0, 0.0, 0.0, 0.0])), ] ) @@ -60,7 +60,7 @@ def test_image_normalize_intensityd(self, im_type): self.assertEqual(type(im), type(normalized)) if isinstance(normalized, torch.Tensor): self.assertEqual(im.device, normalized.device) - assert_allclose(normalized, expected, rtol=1e-3) + assert_allclose(normalized, im_type(expected), rtol=1e-3) @parameterized.expand(TESTS) def test_nonzero(self, input_param, input_data, expected_data): @@ -82,7 +82,7 @@ def test_channel_wise(self, im_type): if isinstance(normalized, torch.Tensor): self.assertEqual(input_data[key].device, normalized.device) expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]]) - assert_allclose(normalized, expected) + assert_allclose(normalized, im_type(expected)) if __name__ == "__main__": diff --git a/tests/test_rand_affine.py b/tests/test_rand_affine.py index 1e1a23bc09..c88aa538ed 100644 --- a/tests/test_rand_affine.py +++ b/tests/test_rand_affine.py @@ -16,114 +16,132 @@ from parameterized import parameterized from monai.transforms import RandAffine +from monai.utils.type_conversion import convert_data_type +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - dict(as_tensor_output=False, device=None), - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=-1), - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None), - {"img": torch.arange(27).reshape((3, 3, 3)), "spatial_size": (2, 2)}, - np.array([[[2.0, 3.0], [5.0, 6.0]], [[11.0, 12.0], [14.0, 15.0]], [[20.0, 21.0], [23.0, 24.0]]]), - ], - [ - dict(as_tensor_output=True, device=None), - {"img": torch.ones((1, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict(as_tensor_output=True, device=None, spatial_size=(2, 2, 2), cache_grid=True), - {"img": torch.ones((1, 3, 3, 3))}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - padding_mode="zeros", - spatial_size=(2, 2, 2), - device=None, - ), - {"img": torch.ones((1, 3, 3, 3)), "mode": "bilinear"}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - padding_mode="zeros", - spatial_size=(2, 2, 2), - cache_grid=True, - device=None, - ), - {"img": torch.ones((1, 3, 3, 3)), "mode": "bilinear"}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=True, - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "spatial_size": (3, 3)}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - spatial_size=(3, 3), - cache_grid=True, - as_tensor_output=True, - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8))}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + dict(device=device), + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=-1), + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( + [ + dict(device=device), + {"img": p(torch.arange(27).reshape((3, 3, 3))), "spatial_size": (2, 2)}, + p(np.array([[[2.0, 3.0], [5.0, 6.0]], [[11.0, 12.0], [14.0, 15.0]], [[20.0, 21.0], [23.0, 24.0]]])), + ] + ) + TESTS.append( + [ + dict(device=device), + {"img": p(torch.ones((1, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2, 2), cache_grid=True), + {"img": p(torch.ones((1, 3, 3, 3)))}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + padding_mode="zeros", + spatial_size=(2, 2, 2), + device=device, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "mode": "bilinear"}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + padding_mode="zeros", + spatial_size=(2, 2, 2), + cache_grid=True, + device=device, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "mode": "bilinear"}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "spatial_size": (3, 3)}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + cache_grid=True, + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8)))}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) -ARR_NUMPY = np.arange(9 * 10).reshape(1, 9, 10) -ARR_TORCH = torch.Tensor(ARR_NUMPY) TEST_CASES_SKIPPED_CONSISTENCY = [] -for im in (ARR_NUMPY, ARR_TORCH): - for as_tensor_output in (True, False): - for in_dtype_is_int in (True, False): - TEST_CASES_SKIPPED_CONSISTENCY.append((im, as_tensor_output, in_dtype_is_int)) +for p in TEST_NDARRAYS: + for in_dtype in (np.int32, np.float32): + TEST_CASES_SKIPPED_CONSISTENCY.append((p(np.arange(9 * 10).reshape(1, 9, 10)), in_dtype)) class TestRandAffine(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affine(self, input_param, input_data, expected_val): g = RandAffine(**input_param) g.set_random_state(123) result = g(**input_data) if input_param.get("cache_grid", False): self.assertTrue(g._cached_grid is not None) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) def test_ill_cache(self): with self.assertWarns(UserWarning): @@ -132,15 +150,11 @@ def test_ill_cache(self): RandAffine(cache_grid=True, spatial_size=(1, 1, -1)) @parameterized.expand(TEST_CASES_SKIPPED_CONSISTENCY) - def test_skipped_transform_consistency(self, im, as_tensor_output, in_dtype_is_int): - t1 = RandAffine(prob=0, as_tensor_output=as_tensor_output) - t2 = RandAffine(prob=1, spatial_size=(10, 11), as_tensor_output=as_tensor_output) + def test_skipped_transform_consistency(self, im, in_dtype): + t1 = RandAffine(prob=0) + t2 = RandAffine(prob=1, spatial_size=(10, 11)) - # change dtype to int32 or float32 - if in_dtype_is_int: - im = im.astype("int32") if isinstance(im, np.ndarray) else im.int() - else: - im = im.astype("float32") if isinstance(im, np.ndarray) else im.float() + im, *_ = convert_data_type(im, dtype=in_dtype) out1 = t1(im) out2 = t2(im) diff --git a/tests/test_rand_affine_grid.py b/tests/test_rand_affine_grid.py index 605d0a30ba..4fb534aba1 100644 --- a/tests/test_rand_affine_grid.py +++ b/tests/test_rand_affine_grid.py @@ -16,182 +16,192 @@ from parameterized import parameterized from monai.transforms import RandAffineGrid +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [{"as_tensor_output": False, "device": None}, {"grid": torch.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [ - {"rotate_range": (1, 2), "translate_range": (3, 3, 3)}, - {"grid": torch.arange(0, 27).reshape((3, 3, 3))}, - torch.tensor( - np.array( - [ - [ - [-32.81998, -33.910976, -35.001972], - [-36.092968, -37.183964, -38.27496], - [-39.36596, -40.456955, -41.54795], - ], - [[2.1380205, 3.1015975, 4.0651755], [5.028752, 5.9923296, 6.955907], [7.919484, 8.883063, 9.84664]], - [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0], [24.0, 25.0, 26.0]], - ] - ) - ), - ], - [ - {"translate_range": (3, 3, 3), "as_tensor_output": False, "device": torch.device("cpu:0")}, - {"spatial_size": (3, 3, 3)}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append([{"device": device}, {"grid": p(torch.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append( [ - [ - [ - [0.17881513, 0.17881513, 0.17881513], - [0.17881513, 0.17881513, 0.17881513], - [0.17881513, 0.17881513, 0.17881513], - ], - [ - [1.1788151, 1.1788151, 1.1788151], - [1.1788151, 1.1788151, 1.1788151], - [1.1788151, 1.1788151, 1.1788151], - ], - [ - [2.1788151, 2.1788151, 2.1788151], - [2.1788151, 2.1788151, 2.1788151], - [2.1788151, 2.1788151, 2.1788151], - ], - ], - [ - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - ], - [ - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - ], - [ - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - ], - ] - ), - ], - [ - {"rotate_range": (1.0, 1.0, 1.0), "shear_range": (0.1,), "scale_range": (1.2,)}, - {"grid": torch.arange(0, 108).reshape((4, 3, 3, 3))}, - torch.tensor( - np.array( - [ - [ - [ - [-9.4201e00, -8.1672e00, -6.9143e00], - [-5.6614e00, -4.4085e00, -3.1556e00], - [-1.9027e00, -6.4980e-01, 6.0310e-01], - ], - [ - [1.8560e00, 3.1089e00, 4.3618e00], - [5.6147e00, 6.8676e00, 8.1205e00], - [9.3734e00, 1.0626e01, 1.1879e01], - ], + {"rotate_range": (1, 2), "translate_range": (3, 3, 3)}, + {"grid": p(torch.arange(0, 27).reshape((3, 3, 3)))}, + p( + np.array( [ - [1.3132e01, 1.4385e01, 1.5638e01], - [1.6891e01, 1.8144e01, 1.9397e01], - [2.0650e01, 2.1902e01, 2.3155e01], - ], - ], - [ - [ - [9.9383e-02, -4.8845e-01, -1.0763e00], - [-1.6641e00, -2.2519e00, -2.8398e00], - [-3.4276e00, -4.0154e00, -4.6032e00], - ], - [ - [-5.1911e00, -5.7789e00, -6.3667e00], - [-6.9546e00, -7.5424e00, -8.1302e00], - [-8.7180e00, -9.3059e00, -9.8937e00], - ], - [ - [-1.0482e01, -1.1069e01, -1.1657e01], - [-1.2245e01, -1.2833e01, -1.3421e01], - [-1.4009e01, -1.4596e01, -1.5184e01], - ], - ], + [ + [-32.81998, -33.910976, -35.001972], + [-36.092968, -37.183964, -38.27496], + [-39.36596, -40.456955, -41.54795], + ], + [ + [2.1380205, 3.1015975, 4.0651755], + [5.028752, 5.9923296, 6.955907], + [7.919484, 8.883063, 9.84664], + ], + [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0], [24.0, 25.0, 26.0]], + ] + ) + ), + ] + ) + TESTS.append( + [ + {"translate_range": (3, 3, 3), "device": device}, + {"spatial_size": (3, 3, 3)}, + np.array( [ [ - [5.9635e01, 6.1199e01, 6.2764e01], - [6.4328e01, 6.5892e01, 6.7456e01], - [6.9021e01, 7.0585e01, 7.2149e01], - ], - [ - [7.3714e01, 7.5278e01, 7.6842e01], - [7.8407e01, 7.9971e01, 8.1535e01], - [8.3099e01, 8.4664e01, 8.6228e01], + [ + [0.17881513, 0.17881513, 0.17881513], + [0.17881513, 0.17881513, 0.17881513], + [0.17881513, 0.17881513, 0.17881513], + ], + [ + [1.1788151, 1.1788151, 1.1788151], + [1.1788151, 1.1788151, 1.1788151], + [1.1788151, 1.1788151, 1.1788151], + ], + [ + [2.1788151, 2.1788151, 2.1788151], + [2.1788151, 2.1788151, 2.1788151], + [2.1788151, 2.1788151, 2.1788151], + ], ], [ - [8.7792e01, 8.9357e01, 9.0921e01], - [9.2485e01, 9.4049e01, 9.5614e01], - [9.7178e01, 9.8742e01, 1.0031e02], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], ], - ], - [ [ - [8.1000e01, 8.2000e01, 8.3000e01], - [8.4000e01, 8.5000e01, 8.6000e01], - [8.7000e01, 8.8000e01, 8.9000e01], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], ], [ - [9.0000e01, 9.1000e01, 9.2000e01], - [9.3000e01, 9.4000e01, 9.5000e01], - [9.6000e01, 9.7000e01, 9.8000e01], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], ], + ] + ), + ] + ) + TESTS.append( + [ + {"device": device, "rotate_range": (1.0, 1.0, 1.0), "shear_range": (0.1,), "scale_range": (1.2,)}, + {"grid": p(torch.arange(0, 108).reshape((4, 3, 3, 3)))}, + p( + np.array( [ - [9.9000e01, 1.0000e02, 1.0100e02], - [1.0200e02, 1.0300e02, 1.0400e02], - [1.0500e02, 1.0600e02, 1.0700e02], - ], - ], - ] - ) - ), - ], -] + [ + [ + [-9.4201e00, -8.1672e00, -6.9143e00], + [-5.6614e00, -4.4085e00, -3.1556e00], + [-1.9027e00, -6.4980e-01, 6.0310e-01], + ], + [ + [1.8560e00, 3.1089e00, 4.3618e00], + [5.6147e00, 6.8676e00, 8.1205e00], + [9.3734e00, 1.0626e01, 1.1879e01], + ], + [ + [1.3132e01, 1.4385e01, 1.5638e01], + [1.6891e01, 1.8144e01, 1.9397e01], + [2.0650e01, 2.1902e01, 2.3155e01], + ], + ], + [ + [ + [9.9383e-02, -4.8845e-01, -1.0763e00], + [-1.6641e00, -2.2519e00, -2.8398e00], + [-3.4276e00, -4.0154e00, -4.6032e00], + ], + [ + [-5.1911e00, -5.7789e00, -6.3667e00], + [-6.9546e00, -7.5424e00, -8.1302e00], + [-8.7180e00, -9.3059e00, -9.8937e00], + ], + [ + [-1.0482e01, -1.1069e01, -1.1657e01], + [-1.2245e01, -1.2833e01, -1.3421e01], + [-1.4009e01, -1.4596e01, -1.5184e01], + ], + ], + [ + [ + [5.9635e01, 6.1199e01, 6.2764e01], + [6.4328e01, 6.5892e01, 6.7456e01], + [6.9021e01, 7.0585e01, 7.2149e01], + ], + [ + [7.3714e01, 7.5278e01, 7.6842e01], + [7.8407e01, 7.9971e01, 8.1535e01], + [8.3099e01, 8.4664e01, 8.6228e01], + ], + [ + [8.7792e01, 8.9357e01, 9.0921e01], + [9.2485e01, 9.4049e01, 9.5614e01], + [9.7178e01, 9.8742e01, 1.0031e02], + ], + ], + [ + [ + [8.1000e01, 8.2000e01, 8.3000e01], + [8.4000e01, 8.5000e01, 8.6000e01], + [8.7000e01, 8.8000e01, 8.9000e01], + ], + [ + [9.0000e01, 9.1000e01, 9.2000e01], + [9.3000e01, 9.4000e01, 9.5000e01], + [9.6000e01, 9.7000e01, 9.8000e01], + ], + [ + [9.9000e01, 1.0000e02, 1.0100e02], + [1.0200e02, 1.0300e02, 1.0400e02], + [1.0500e02, 1.0600e02, 1.0700e02], + ], + ], + ] + ) + ), + ] + ) class TestRandAffineGrid(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affine_grid(self, input_param, input_data, expected_val): g = RandAffineGrid(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + if "device" in input_data: + self.assertEqual(result.device, input_data[device]) + assert_allclose(result, expected_val, type_test=False, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py index d2f8a60665..bec9602d62 100644 --- a/tests/test_rand_affined.py +++ b/tests/test_rand_affined.py @@ -17,179 +17,188 @@ from monai.transforms import RandAffined from monai.utils import GridSampleMode +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - dict(as_tensor_output=False, device=None, spatial_size=None, keys=("img", "seg")), - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=(2, 2), keys=("img", "seg")), - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=(2, 2), cache_grid=True, keys=("img", "seg")), - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - dict(as_tensor_output=True, device=None, spatial_size=(2, 2, 2), keys=("img", "seg")), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - keys=("img", "seg"), - mode="bilinear", - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=False, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - cache_grid=True, - keys=("img", "seg"), - mode="bilinear", - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - np.array([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=True, - spatial_size=(3, 3), - keys=("img", "seg"), - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], - [ - dict( - prob=0.9, - mode=("bilinear", "nearest"), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - keys=("img", "seg"), - mode=GridSampleMode.BILINEAR, - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], - [ - dict( - prob=0.9, - mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - cache_grid=True, - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + dict(device=device, spatial_size=None, keys=("img", "seg")), + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2), keys=("img", "seg")), + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2), cache_grid=True, keys=("img", "seg")), + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2, 2), keys=("img", "seg")), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + spatial_size=(2, 2, 2), + padding_mode="zeros", + device=device, + keys=("img", "seg"), + mode="bilinear", + ), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=("bilinear", "nearest"), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + spatial_size=(2, 2, 2), + padding_mode="zeros", + device=device, + keys=("img", "seg"), + mode=GridSampleMode.BILINEAR, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + cache_grid=True, + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) class TestRandAffined(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affined(self, input_param, input_data, expected_val): g = RandAffined(**input_param).set_random_state(123) res = g(input_data) @@ -200,23 +209,16 @@ def test_rand_affined(self, input_param, input_data, expected_val): if "_transforms" in key: continue expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected, rtol=1e-4, atol=1e-4) def test_ill_cache(self): with self.assertWarns(UserWarning): # spatial size is None - RandAffined( - as_tensor_output=False, device=None, spatial_size=None, prob=1.0, cache_grid=True, keys=("img", "seg") - ) + RandAffined(device=device, spatial_size=None, prob=1.0, cache_grid=True, keys=("img", "seg")) with self.assertWarns(UserWarning): # spatial size is dynamic RandAffined( - as_tensor_output=False, - device=None, + device=device, spatial_size=(2, -1), prob=1.0, cache_grid=True, diff --git a/tests/test_rand_axis_flip.py b/tests/test_rand_axis_flip.py index c05c3a1e0d..1772ef4987 100644 --- a/tests/test_rand_axis_flip.py +++ b/tests/test_rand_axis_flip.py @@ -22,10 +22,8 @@ def test_correct_results(self): for p in TEST_NDARRAYS: flip = RandAxisFlip(prob=1.0) result = flip(p(self.imt[0])) - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, flip._axis)) - assert_allclose(np.stack(expected), result) + expected = [np.flip(channel, flip._axis) for channel in self.imt[0]] + assert_allclose(result, p(np.stack(expected))) if __name__ == "__main__": diff --git a/tests/test_rand_axis_flipd.py b/tests/test_rand_axis_flipd.py index 7bef0baa63..37a17db69f 100644 --- a/tests/test_rand_axis_flipd.py +++ b/tests/test_rand_axis_flipd.py @@ -23,10 +23,8 @@ def test_correct_results(self): flip = RandAxisFlipd(keys="img", prob=1.0) result = flip({"img": p(self.imt[0])})["img"] - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, flip._axis)) - assert_allclose(np.stack(expected), result) + expected = [np.flip(channel, flip._axis) for channel in self.imt[0]] + assert_allclose(result, p(np.stack(expected))) if __name__ == "__main__": diff --git a/tests/test_rand_elastic_2d.py b/tests/test_rand_elastic_2d.py index fbfb7d5761..c414eb1ffd 100644 --- a/tests/test_rand_elastic_2d.py +++ b/tests/test_rand_elastic_2d.py @@ -16,90 +16,101 @@ from parameterized import parameterized from monai.transforms import Rand2DElastic +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "as_tensor_output": False, "device": None}, - {"img": torch.ones((3, 3, 3)), "spatial_size": (2, 2)}, - np.ones((3, 2, 2)), - ], - [ - {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "as_tensor_output": False, "device": None}, - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - { - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "padding_mode": "zeros", - }, - {"img": torch.ones((3, 3, 3)), "spatial_size": (2, 2), "mode": "bilinear"}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], + {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "device": device}, + {"img": p(torch.ones((3, 3, 3))), "spatial_size": (2, 2)}, + p(np.ones((3, 2, 2))), ] - ), - ], - [ - { - "spacing": (1.0, 1.0), - "magnitude_range": (1.0, 1.0), - "scale_range": [1.2, 2.2], - "prob": 0.9, - "padding_mode": "border", - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3))}, - torch.tensor( + ) + TESTS.append( [ - [[3.0793, 2.6141], [4.0568, 5.9978]], - [[12.0793, 11.6141], [13.0568, 14.9978]], - [[21.0793, 20.6141], [22.0568, 23.9978]], + {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "device": device}, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), ] - ), - ], - [ - { - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": False, - "device": "cuda" if torch.cuda.is_available() else "cpu", - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.array( + ) + TESTS.append( [ - [[1.3584113, 1.9251312], [5.626623, 6.642721]], - [[10.358411, 10.925131], [14.626623, 15.642721]], - [[19.358412, 19.92513], [23.626623, 24.642721]], + { + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.9, + "device": device, + "padding_mode": "zeros", + }, + {"img": p(torch.ones((3, 3, 3))), "spatial_size": (2, 2), "mode": "bilinear"}, + p( + np.array( + [ + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + ] + ) + ), ] - ), - ], -] + ) + TESTS.append( + [ + { + "spacing": (1.0, 1.0), + "magnitude_range": (1.0, 1.0), + "scale_range": [1.2, 2.2], + "prob": 0.9, + "padding_mode": "border", + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + torch.tensor( + [ + [[3.0793, 2.6141], [4.0568, 5.9978]], + [[12.0793, 11.6141], [13.0568, 14.9978]], + [[21.0793, 20.6141], [22.0568, 23.9978]], + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + np.array( + [ + [[1.3584113, 1.9251312], [5.626623, 6.642721]], + [[10.358411, 10.925131], [14.626623, 15.642721]], + [[19.358412, 19.92513], [23.626623, 24.642721]], + ] + ) + ), + ] + ) class TestRand2DElastic(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_2d_elastic(self, input_param, input_data, expected_val): g = Rand2DElastic(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_elastic_3d.py b/tests/test_rand_elastic_3d.py index c63282d571..d44324746f 100644 --- a/tests/test_rand_elastic_3d.py +++ b/tests/test_rand_elastic_3d.py @@ -16,69 +16,89 @@ from parameterized import parameterized from monai.transforms import Rand3DElastic +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - { - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(72).reshape((2, 3, 3, 4))}, - np.arange(72).reshape((2, 3, 3, 4)), - ], - [ - { - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - }, - {"img": torch.ones((2, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - np.ones((2, 2, 2, 2)), - ], - [ - { - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - np.array([[[[6.4939356, 7.50289], [9.518351, 10.522849]], [[15.512375, 16.523542], [18.531467, 19.53646]]]]), - ], - [ - { - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": False, - "device": "cuda" if torch.cuda.is_available() else "cpu", - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "mode": "bilinear"}, - np.array([[[[5.0069294, 9.463932], [9.287769, 13.739735]], [[12.319424, 16.777205], [16.594296, 21.045748]]]]), - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + { + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(72).reshape((2, 3, 3, 4)))}, + p(np.arange(72).reshape((2, 3, 3, 4))), + ] + ) + TESTS.append( + [ + { + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + }, + {"img": p(torch.ones((2, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p(np.ones((2, 2, 2, 2))), + ] + ) + TESTS.append( + [ + { + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "device": device, + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p( + np.array( + [ + [ + [[6.4939356, 7.50289], [9.518351, 10.522849]], + [[15.512375, 16.523542], [18.531467, 19.53646]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "mode": "bilinear"}, + p( + np.array( + [ + [ + [[5.0069294, 9.463932], [9.287769, 13.739735]], + [[12.319424, 16.777205], [16.594296, 21.045748]], + ] + ] + ) + ), + ] + ) class TestRand3DElastic(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_3d_elastic(self, input_param, input_data, expected_val): g = Rand3DElastic(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_elasticd_2d.py b/tests/test_rand_elasticd_2d.py index f8eb026088..84f18120e1 100644 --- a/tests/test_rand_elasticd_2d.py +++ b/tests/test_rand_elasticd_2d.py @@ -16,127 +16,147 @@ from parameterized import parameterized from monai.transforms import Rand2DElasticd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.3, 0.3), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(4).reshape((1, 2, 2)), "seg": torch.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape((1, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "padding_mode": "zeros", - "device": None, - "spatial_size": (2, 2), - "mode": "bilinear", - }, - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), ] - ), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (1.0, 1.0), - "magnitude_range": (1.0, 1.0), - "scale_range": [1.2, 2.2], - "prob": 0.9, - "padding_mode": "border", - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - torch.tensor( + ) + TESTS.append( [ - [[3.0793, 2.6141], [4.0568, 5.9978]], - [[12.0793, 11.6141], [13.0568, 14.9978]], - [[21.0793, 20.6141], [22.0568, 23.9978]], + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.3, 0.3), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(4).reshape((1, 2, 2))), "seg": p(torch.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape((1, 2, 2))), ] - ), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - np.array( + ) + TESTS.append( [ - [[1.3584113, 1.9251312], [5.626623, 6.642721]], - [[10.358411, 10.925131], [14.626623, 15.642721]], - [[19.358412, 19.92513], [23.626623, 24.642721]], + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.9, + "padding_mode": "zeros", + "device": device, + "spatial_size": (2, 2), + "mode": "bilinear", + }, + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p( + np.array( + [ + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + ] + ) + ), ] - ), - ], - [ - { - "keys": ("img", "seg"), - "mode": ("bilinear", "nearest"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - { - "img": torch.tensor( - [ - [[1.3584, 1.9251], [5.6266, 6.6427]], - [[10.3584, 10.9251], [14.6266, 15.6427]], - [[19.3584, 19.9251], [23.6266, 24.6427]], - ] - ), - "seg": torch.tensor([[[0.0, 2.0], [6.0, 8.0]], [[9.0, 11.0], [15.0, 17.0]], [[18.0, 20.0], [24.0, 26.0]]]), - }, - ], -] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "spacing": (1.0, 1.0), + "magnitude_range": (1.0, 1.0), + "scale_range": [1.2, 2.2], + "prob": 0.9, + "padding_mode": "border", + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + torch.tensor( + [ + [[3.0793, 2.6141], [4.0568, 5.9978]], + [[12.0793, 11.6141], [13.0568, 14.9978]], + [[21.0793, 20.6141], [22.0568, 23.9978]], + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + np.array( + [ + [[1.3584113, 1.9251312], [5.626623, 6.642721]], + [[10.358411, 10.925131], [14.626623, 15.642721]], + [[19.358412, 19.92513], [23.626623, 24.642721]], + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "mode": ("bilinear", "nearest"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + { + "img": p( + torch.tensor( + [ + [[1.3584, 1.9251], [5.6266, 6.6427]], + [[10.3584, 10.9251], [14.6266, 15.6427]], + [[19.3584, 19.9251], [23.6266, 24.6427]], + ] + ) + ), + "seg": p( + torch.tensor( + [[[0.0, 2.0], [6.0, 8.0]], [[9.0, 11.0], [15.0, 17.0]], [[18.0, 20.0], [24.0, 26.0]]] + ) + ), + }, + ] + ) class TestRand2DElasticd(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_2d_elasticd(self, input_param, input_data, expected_val): g = Rand2DElasticd(**input_param) g.set_random_state(123) @@ -144,11 +164,7 @@ def test_rand_2d_elasticd(self, input_param, input_data, expected_val): for key in res: result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_elasticd_3d.py b/tests/test_rand_elasticd_3d.py index 47ab814882..5f8a5f47ed 100644 --- a/tests/test_rand_elasticd_3d.py +++ b/tests/test_rand_elasticd_3d.py @@ -16,98 +16,128 @@ from parameterized import parameterized from monai.transforms import Rand3DElasticd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - }, - {"img": torch.ones((2, 3, 3, 3)), "seg": torch.ones((2, 3, 3, 3))}, - np.ones((2, 2, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, -1, -1), - }, - {"img": torch.ones((2, 3, 3, 3)), "seg": torch.ones((2, 3, 3, 3))}, - np.ones((2, 2, 3, 3)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(8).reshape((1, 2, 2, 2)), "seg": torch.arange(8).reshape((1, 2, 2, 2))}, - np.arange(8).reshape((1, 2, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - np.array([[[[6.4939356, 7.50289], [9.518351, 10.522849]], [[15.512375, 16.523542], [18.531467, 19.53646]]]]), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - "mode": "bilinear", - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - np.array([[[[5.0069294, 9.463932], [9.287769, 13.739735]], [[12.319424, 16.777205], [16.594296, 21.045748]]]]), - ], - [ - { - "keys": ("img", "seg"), - "mode": ("bilinear", "nearest"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": True, - "device": torch.device("cpu:0"), - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - { - "img": torch.tensor([[[[5.0069, 9.4639], [9.2878, 13.7397]], [[12.3194, 16.7772], [16.5943, 21.0457]]]]), - "seg": torch.tensor([[[[4.0, 14.0], [7.0, 14.0]], [[9.0, 19.0], [12.0, 22.0]]]]), - }, - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.ones((2, 3, 3, 3))), "seg": p(torch.ones((2, 3, 3, 3)))}, + p(np.ones((2, 2, 2, 2))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, -1, -1), + }, + {"img": p(torch.ones((2, 3, 3, 3))), "seg": p(torch.ones((2, 3, 3, 3)))}, + p(np.ones((2, 2, 3, 3))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(8).reshape((1, 2, 2, 2))), "seg": p(torch.arange(8).reshape((1, 2, 2, 2)))}, + p(np.arange(8).reshape((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + p( + np.array( + [ + [ + [[6.4939356, 7.50289], [9.518351, 10.522849]], + [[15.512375, 16.523542], [18.531467, 19.53646]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + "mode": "bilinear", + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + p( + np.array( + [ + [ + [[5.0069294, 9.463932], [9.287769, 13.739735]], + [[12.319424, 16.777205], [16.594296, 21.045748]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "mode": ("bilinear", "nearest"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + { + "img": p( + torch.tensor( + [[[[5.0069, 9.4639], [9.2878, 13.7397]], [[12.3194, 16.7772], [16.5943, 21.0457]]]] + ) + ), + "seg": p(torch.tensor([[[[4.0, 14.0], [7.0, 14.0]], [[9.0, 19.0], [12.0, 22.0]]]])), + }, + ] + ) class TestRand3DElasticd(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_3d_elasticd(self, input_param, input_data, expected_val): g = Rand3DElasticd(**input_param) g.set_random_state(123) @@ -115,11 +145,7 @@ def test_rand_3d_elasticd(self, input_param, input_data, expected_val): for key in res: result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_flip.py b/tests/test_rand_flip.py index b3c514cb1f..df49d60861 100644 --- a/tests/test_rand_flip.py +++ b/tests/test_rand_flip.py @@ -34,12 +34,10 @@ def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS: im = p(self.imt[0]) flip = RandFlip(prob=1.0, spatial_axis=spatial_axis) - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, spatial_axis)) + expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) result = flip(im) - assert_allclose(expected, result) + assert_allclose(result, p(expected)) if __name__ == "__main__": diff --git a/tests/test_rand_flipd.py b/tests/test_rand_flipd.py index 8972024fd8..c2869537cb 100644 --- a/tests/test_rand_flipd.py +++ b/tests/test_rand_flipd.py @@ -26,11 +26,9 @@ def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS: flip = RandFlipd(keys="img", prob=1.0, spatial_axis=spatial_axis) result = flip({"img": p(self.imt[0])})["img"] - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, spatial_axis)) + expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(expected, result) + assert_allclose(result, p(expected)) if __name__ == "__main__": diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py index 0ff8508a0f..4817e81735 100644 --- a/tests/test_rand_rotate.py +++ b/tests/test_rand_rotate.py @@ -10,25 +10,60 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import RandRotate -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, np.pi / 2, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4, True, "nearest", "border", False)) + TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", True)) + TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, "nearest", "zeros", True)) -class TestRandRotate2D(NumpyImageTestCase2D): - @parameterized.expand( - [ - (np.pi / 2, True, "bilinear", "border", False), - (np.pi / 4, True, "nearest", "border", False), - (np.pi, False, "nearest", "zeros", True), - ((-np.pi / 4, 0), False, "nearest", "zeros", True), - ] +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append( + (p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + "nearest", + "border", + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + "nearest", + "zeros", + True, + (1, 48, 64, 80), + ) ) - def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_corners): + TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90))) + + +class TestRandRotate2D(NumpyImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) + def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners): rotate_fn = RandRotate( range_x=degrees, prob=1.0, @@ -38,7 +73,7 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn(self.imt[0]) + rotated = rotate_fn(im_type(self.imt[0])) _order = 0 if mode == "nearest" else 1 if mode == "border": @@ -52,38 +87,14 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) expected = np.stack(expected).astype(np.float32) + rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated good = np.sum(np.isclose(expected, rotated[0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") class TestRandRotate3D(NumpyImageTestCase3D): - @parameterized.expand( - [ - (np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - "nearest", - "border", - True, - (1, 89, 105, 104), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - "nearest", - "zeros", - True, - (1, 48, 64, 80), - ), - ((-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)), - ] - ) - def test_correct_results(self, x, y, z, keep_size, mode, padding_mode, align_corners, expected): + @parameterized.expand(TEST_CASES_3D) + def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected): rotate_fn = RandRotate( range_x=x, range_y=y, @@ -95,8 +106,8 @@ def test_correct_results(self, x, y, z, keep_size, mode, padding_mode, align_cor align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn(self.imt[0]) - np.testing.assert_allclose(rotated.shape, expected) + rotated = rotate_fn(im_type(self.imt[0])) + torch.testing.assert_allclose(rotated.shape, expected, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py index f339158f94..9fc025fbbe 100644 --- a/tests/test_rand_rotate90.py +++ b/tests/test_rand_rotate90.py @@ -23,44 +23,36 @@ def test_default(self): for p in TEST_NDARRAYS: rotate.set_random_state(123) rotated = rotate(p(self.imt[0])) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_k(self): rotate = RandRotate90(max_k=2) for p in TEST_NDARRAYS: rotate.set_random_state(234) rotated = rotate(p(self.imt[0])) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_spatial_axes(self): rotate = RandRotate90(spatial_axes=(0, 1)) for p in TEST_NDARRAYS: rotate.set_random_state(234) rotated = rotate(p(self.imt[0])) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_prob_k_spatial_axes(self): rotate = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1)) for p in TEST_NDARRAYS: rotate.set_random_state(234) rotated = rotate(p(self.imt[0])) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) + expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) if __name__ == "__main__": diff --git a/tests/test_rand_rotate90d.py b/tests/test_rand_rotate90d.py index f9083afb0c..3071aa82c8 100644 --- a/tests/test_rand_rotate90d.py +++ b/tests/test_rand_rotate90d.py @@ -24,11 +24,9 @@ def test_default(self): for p in TEST_NDARRAYS: rotate.set_random_state(123) rotated = rotate({key: p(self.imt[0])}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(rotated[key], expected) + assert_allclose(rotated[key], p(expected)) def test_k(self): key = "test" @@ -36,11 +34,9 @@ def test_k(self): for p in TEST_NDARRAYS: rotate.set_random_state(234) rotated = rotate({key: p(self.imt[0])}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(rotated[key], expected) + assert_allclose(rotated[key], p(expected)) def test_spatial_axes(self): key = "test" @@ -48,11 +44,9 @@ def test_spatial_axes(self): for p in TEST_NDARRAYS: rotate.set_random_state(234) rotated = rotate({key: p(self.imt[0])}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(rotated[key], expected) + assert_allclose(rotated[key], p(expected)) def test_prob_k_spatial_axes(self): key = "test" @@ -60,11 +54,9 @@ def test_prob_k_spatial_axes(self): for p in TEST_NDARRAYS: rotate.set_random_state(234) rotated = rotate({key: p(self.imt[0])}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) + expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(rotated[key], expected) + assert_allclose(rotated[key], p(expected)) def test_no_key(self): key = "unknown" diff --git a/tests/test_rand_rotated.py b/tests/test_rand_rotated.py index 47b4b7107e..4c9a27f668 100644 --- a/tests/test_rand_rotated.py +++ b/tests/test_rand_rotated.py @@ -10,26 +10,104 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import RandRotated from monai.utils import GridSampleMode, GridSamplePadMode -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, np.pi / 2, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4, True, "nearest", "border", False)) + TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", True)) + TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, "nearest", "zeros", True)) -class TestRandRotated2D(NumpyImageTestCase2D): - @parameterized.expand( - [ - (np.pi / 2, True, "bilinear", "border", False), - (np.pi / 4, True, "nearest", "border", False), - (np.pi, False, "nearest", "zeros", True), - ((-np.pi / 4, 0), False, "nearest", "zeros", True), - ] + +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append( + (p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)) ) - def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_corners): + TEST_CASES_3D.append( + ( + p, + np.pi / 2, + -np.pi / 6, + (0.0, np.pi), + False, + GridSampleMode.NEAREST, + GridSamplePadMode.BORDER, + False, + (1, 87, 104, 109), + ) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + "nearest", + "border", + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + GridSampleMode.NEAREST, + GridSamplePadMode.BORDER, + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + "nearest", + "zeros", + True, + (1, 48, 64, 80), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + GridSampleMode.NEAREST, + GridSamplePadMode.ZEROS, + True, + (1, 48, 64, 80), + ) + ) + TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90))) + TEST_CASES_3D.append( + (p, (-np.pi / 4, 0), 0, 0, False, GridSampleMode.NEAREST, GridSamplePadMode.ZEROS, False, (1, 48, 77, 90)) + ) + + +class TestRandRotated2D(NumpyImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) + def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners): rotate_fn = RandRotated( "img", range_x=degrees, @@ -40,7 +118,7 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) _order = 0 if mode == "nearest" else 1 if padding_mode == "border": @@ -53,70 +131,16 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v expected = np.stack(expected).astype(np.float32) good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") class TestRandRotated3D(NumpyImageTestCase3D): - @parameterized.expand( - [ - (np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)), - ( - np.pi / 2, - -np.pi / 6, - (0.0, np.pi), - False, - GridSampleMode.NEAREST, - GridSamplePadMode.BORDER, - False, - (1, 87, 104, 109), - ), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - "nearest", - "border", - True, - (1, 89, 105, 104), - ), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - GridSampleMode.NEAREST, - GridSamplePadMode.BORDER, - True, - (1, 89, 105, 104), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - "nearest", - "zeros", - True, - (1, 48, 64, 80), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - GridSampleMode.NEAREST, - GridSamplePadMode.ZEROS, - True, - (1, 48, 64, 80), - ), - ((-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)), - ((-np.pi / 4, 0), 0, 0, False, GridSampleMode.NEAREST, GridSamplePadMode.ZEROS, False, (1, 48, 77, 90)), - ] - ) - def test_correct_shapes(self, x, y, z, keep_size, mode, padding_mode, align_corners, expected): + @parameterized.expand(TEST_CASES_3D) + def test_correct_shapes(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected): rotate_fn = RandRotated( "img", range_x=x, @@ -129,7 +153,7 @@ def test_correct_shapes(self, x, y, z, keep_size, mode, padding_mode, align_corn align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) np.testing.assert_allclose(rotated["img"].shape, expected) diff --git a/tests/test_rand_scale_intensity.py b/tests/test_rand_scale_intensity.py index 750d88bfad..b863e2f874 100644 --- a/tests/test_rand_scale_intensity.py +++ b/tests/test_rand_scale_intensity.py @@ -25,7 +25,7 @@ def test_value(self): result = scaler(p(self.imt)) np.random.seed(0) expected = p((self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)) - assert_allclose(result, expected, rtol=1e-7, atol=0) + assert_allclose(result, p(expected), rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_rand_scale_intensityd.py b/tests/test_rand_scale_intensityd.py index a8d2e63f65..fdcbd7146a 100644 --- a/tests/test_rand_scale_intensityd.py +++ b/tests/test_rand_scale_intensityd.py @@ -19,14 +19,14 @@ class TestRandScaleIntensityd(NumpyImageTestCase2D): def test_value(self): + key = "img" for p in TEST_NDARRAYS: - key = "img" scaler = RandScaleIntensityd(keys=[key], factors=0.5, prob=1.0) scaler.set_random_state(seed=0) result = scaler({key: p(self.imt)}) np.random.seed(0) expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32) - assert_allclose(result[key], expected) + assert_allclose(result[key], p(expected)) if __name__ == "__main__": diff --git a/tests/test_rand_shift_intensityd.py b/tests/test_rand_shift_intensityd.py index 6766236146..c5dfb66722 100644 --- a/tests/test_rand_shift_intensityd.py +++ b/tests/test_rand_shift_intensityd.py @@ -19,14 +19,14 @@ class TestRandShiftIntensityd(NumpyImageTestCase2D): def test_value(self): + key = "img" for p in TEST_NDARRAYS: - key = "img" shifter = RandShiftIntensityd(keys=[key], offsets=1.0, prob=1.0) shifter.set_random_state(seed=0) result = shifter({key: p(self.imt)}) np.random.seed(0) expected = self.imt + np.random.uniform(low=-1.0, high=1.0) - assert_allclose(result[key], expected) + assert_allclose(result[key], p(expected)) def test_factor(self): key = "img" diff --git a/tests/test_rand_zoom.py b/tests/test_rand_zoom.py index 0ac1b92c39..6ccb265cca 100644 --- a/tests/test_rand_zoom.py +++ b/tests/test_rand_zoom.py @@ -35,11 +35,13 @@ def test_correct_results(self, min_zoom, max_zoom, mode, keep_size): ) random_zoom.set_random_state(1234) zoomed = random_zoom(p(self.imt[0])) - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False)) + expected = [ + zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False) + for channel in self.imt[0] + ] + expected = np.stack(expected).astype(np.float32) - assert_allclose(zoomed, expected, atol=1.0) + assert_allclose(zoomed, p(expected), atol=1.0) def test_keep_size(self): for p in TEST_NDARRAYS: diff --git a/tests/test_rand_zoomd.py b/tests/test_rand_zoomd.py index fafaf748bd..842d207ca6 100644 --- a/tests/test_rand_zoomd.py +++ b/tests/test_rand_zoomd.py @@ -38,11 +38,13 @@ def test_correct_results(self, min_zoom, max_zoom, mode, align_corners, keep_siz random_zoom.set_random_state(1234) zoomed = random_zoom({key: p(self.imt[0])}) - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False)) + expected = [ + zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False) + for channel in self.imt[0] + ] + expected = np.stack(expected).astype(np.float32) - assert_allclose(expected, zoomed[key], atol=1.0) + assert_allclose(zoomed[key], p(expected), atol=1.0) def test_keep_size(self): key = "img" diff --git a/tests/test_resampler.py b/tests/test_resampler.py index 2be94acebd..af23421ecc 100644 --- a/tests/test_resampler.py +++ b/tests/test_resampler.py @@ -17,69 +17,146 @@ from monai.transforms import Resample from monai.transforms.utils import create_grid +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"grid": create_grid((2, 2)), "img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 1.0], [2.0, 3.0]]]), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(padding_mode="border", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]]]), - ], - [ - dict(padding_mode="reflection", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2)), "mode": "nearest"}, - np.array([[[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]]), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4, 4)), "img": np.arange(8).reshape((1, 2, 2, 2)), "mode": "bilinear"}, - np.array( - [ +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + dict(padding_mode="zeros", device=device), + {"grid": p(create_grid((2, 2))), "img": q(np.arange(4).reshape((1, 2, 2)))}, + q(np.array([[[0.0, 1.0], [2.0, 3.0]]])), ] - ] - ), - ], - [ - dict(padding_mode="border", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4, 4)), "img": np.arange(8).reshape((1, 2, 2, 2)), "mode": "bilinear"}, - np.array( - [ + ) + TESTS.append( [ - [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0], [2.0, 2.0, 3.0, 3.0]], - [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0], [2.0, 2.0, 3.0, 3.0]], - [[4.0, 4.0, 5.0, 5.0], [4.0, 4.0, 5.0, 5.0], [6.0, 6.0, 7.0, 7.0], [6.0, 6.0, 7.0, 7.0]], - [[4.0, 4.0, 5.0, 5.0], [4.0, 4.0, 5.0, 5.0], [6.0, 6.0, 7.0, 7.0], [6.0, 6.0, 7.0, 7.0]], + dict(padding_mode="zeros", device=device), + {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2)))}, + q( + np.array( + [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]] + ) + ), ] - ] - ), - ], -] + ) + TESTS.append( + [ + dict(padding_mode="border", device=device), + {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2)))}, + q( + np.array( + [[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]]] + ) + ), + ] + ) + TESTS.append( + [ + dict(padding_mode="reflection", device=device), + {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2))), "mode": "nearest"}, + q( + np.array( + [[[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]] + ) + ), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + { + "grid": p(create_grid((4, 4, 4))), + "img": q(np.arange(8).reshape((1, 2, 2, 2))), + "mode": "bilinear", + }, + q( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict(padding_mode="border", device=device), + { + "grid": p(create_grid((4, 4, 4))), + "img": q(np.arange(8).reshape((1, 2, 2, 2))), + "mode": "bilinear", + }, + q( + np.array( + [ + [ + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [2.0, 2.0, 3.0, 3.0], + [2.0, 2.0, 3.0, 3.0], + ], + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [2.0, 2.0, 3.0, 3.0], + [2.0, 2.0, 3.0, 3.0], + ], + [ + [4.0, 4.0, 5.0, 5.0], + [4.0, 4.0, 5.0, 5.0], + [6.0, 6.0, 7.0, 7.0], + [6.0, 6.0, 7.0, 7.0], + ], + [ + [4.0, 4.0, 5.0, 5.0], + [4.0, 4.0, 5.0, 5.0], + [6.0, 6.0, 7.0, 7.0], + [6.0, 6.0, 7.0, 7.0], + ], + ] + ] + ) + ), + ] + ) class TestResample(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_resample(self, input_param, input_data, expected_val): g = Resample(**input_param) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + if "device" in input_data: + self.assertEqual(result.device, input_data["device"]) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_resnet.py b/tests/test_resnet.py index c4ba5c2e16..16cd6f4865 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -42,14 +42,26 @@ (2, 3), ] +TEST_CASE_2_A = [ # 2D, batch 2, 1 input channel, shortcut type A + {"pretrained": False, "spatial_dims": 2, "n_input_channels": 1, "num_classes": 3, "shortcut_type": "A"}, + (2, 1, 32, 64), + (2, 3), +] + TEST_CASE_3 = [ # 1D, batch 1, 2 input channels {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3}, (1, 2, 32), (1, 3), ] +TEST_CASE_3_A = [ # 1D, batch 1, 2 input channels + {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3, "shortcut_type": "A"}, + (1, 2, 32), + (1, 3), +] + TEST_CASES = [] -for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]: +for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]: for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]: TEST_CASES.append([model, *case]) diff --git a/tests/test_rotate.py b/tests/test_rotate.py index 436c952d4b..16a9c6d124 100644 --- a/tests/test_rotate.py +++ b/tests/test_rotate.py @@ -10,42 +10,44 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import Rotate -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D - -TEST_CASES_2D = [ - (np.pi / 6, False, "bilinear", "border", False), - (np.pi / 4, True, "bilinear", "border", False), - (-np.pi / 4.5, True, "nearest", "reflection", False), - (np.pi, False, "nearest", "zeros", False), - (-np.pi / 2, False, "bilinear", "zeros", True), -] - -TEST_CASES_3D = [ - (-np.pi / 2, True, "nearest", "border", False), - (np.pi / 4, True, "bilinear", "border", False), - (-np.pi / 4.5, True, "nearest", "reflection", False), - (np.pi, False, "nearest", "zeros", False), - (-np.pi / 2, False, "bilinear", "zeros", False), -] - -TEST_CASES_SHAPE_3D = [ - ([-np.pi / 2, 1.0, 2.0], "nearest", "border", False), - ([np.pi / 4, 0, 0], "bilinear", "border", False), - ([-np.pi / 4.5, -20, 20], "nearest", "reflection", False), -] +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D + +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, np.pi / 6, False, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, -np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", False)) + TEST_CASES_2D.append((p, -np.pi / 2, False, "bilinear", "zeros", True)) + +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append((p, -np.pi / 2, True, "nearest", "border", False)) + TEST_CASES_3D.append((p, np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_3D.append((p, -np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_3D.append((p, np.pi, False, "nearest", "zeros", False)) + TEST_CASES_3D.append((p, -np.pi / 2, False, "bilinear", "zeros", False)) + +TEST_CASES_SHAPE_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_SHAPE_3D.append((p, [-np.pi / 2, 1.0, 2.0], "nearest", "border", False)) + TEST_CASES_SHAPE_3D.append((p, [np.pi / 4, 0, 0], "bilinear", "border", False)) + TEST_CASES_SHAPE_3D.append((p, [-np.pi / 4.5, -20, 20], "nearest", "reflection", False)) class TestRotate2D(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES_2D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotate(angle, keep_size, mode, padding_mode, align_corners) - rotated = rotate_fn(self.imt[0]) + rotated = rotate_fn(im_type(self.imt[0])) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated.shape) _order = 0 if mode == "nearest" else 1 @@ -70,15 +72,16 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne ) ) expected = np.stack(expected).astype(np.float32) + rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated good = np.sum(np.isclose(expected, rotated, atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") class TestRotate3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotate([angle, 0, 0], keep_size, mode, padding_mode, align_corners) - rotated = rotate_fn(self.imt[0]) + rotated = rotate_fn(im_type(self.imt[0])) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated.shape) _order = 0 if mode == "nearest" else 1 @@ -103,23 +106,25 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne ) ) expected = np.stack(expected).astype(np.float32) + rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated n_good = np.sum(np.isclose(expected, rotated, atol=1e-3)) self.assertLessEqual(expected.size - n_good, 5, "diff at most 5 pixels") @parameterized.expand(TEST_CASES_SHAPE_3D) - def test_correct_shape(self, angle, mode, padding_mode, align_corners): + def test_correct_shape(self, im_type, angle, mode, padding_mode, align_corners): rotate_fn = Rotate(angle, True, align_corners=align_corners) - rotated = rotate_fn(self.imt[0], mode=mode, padding_mode=padding_mode) + rotated = rotate_fn(im_type(self.imt[0]), mode=mode, padding_mode=padding_mode) np.testing.assert_allclose(self.imt[0].shape, rotated.shape) def test_ill_case(self): - rotate_fn = Rotate(10, True) - with self.assertRaises(ValueError): # wrong shape - rotate_fn(self.imt) - - rotate_fn = Rotate(10, keep_size=False) - with self.assertRaises(ValueError): # wrong mode - rotate_fn(self.imt[0], mode="trilinear") + for p in TEST_NDARRAYS: + rotate_fn = Rotate(10, True) + with self.assertRaises(ValueError): # wrong shape + rotate_fn(p(self.imt)) + + rotate_fn = Rotate(10, keep_size=False) + with self.assertRaises(ValueError): # wrong mode + rotate_fn(p(self.imt[0]), mode="trilinear") if __name__ == "__main__": diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py index 03a967a16b..9857b26fe8 100644 --- a/tests/test_rotate90.py +++ b/tests/test_rotate90.py @@ -22,41 +22,33 @@ def test_rotate90_default(self): rotate = Rotate90() for p in TEST_NDARRAYS: rotated = rotate(p(self.imt[0])) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) + expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_k(self): rotate = Rotate90(k=2) for p in TEST_NDARRAYS: rotated = rotate(p(self.imt[0])) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) + expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_spatial_axes(self): rotate = Rotate90(spatial_axes=(0, -1)) for p in TEST_NDARRAYS: rotated = rotate(p(self.imt[0])) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, -1))) + expected = [np.rot90(channel, 1, (0, -1)) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_prob_k_spatial_axes(self): rotate = Rotate90(k=2, spatial_axes=(0, 1)) for p in TEST_NDARRAYS: rotated = rotate(p(self.imt[0])) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) + expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) if __name__ == "__main__": diff --git a/tests/test_rotate90d.py b/tests/test_rotate90d.py index a1fa3c977c..a2a4a27521 100644 --- a/tests/test_rotate90d.py +++ b/tests/test_rotate90d.py @@ -23,44 +23,36 @@ def test_rotate90_default(self): rotate = Rotate90d(keys=key) for p in TEST_NDARRAYS: rotated = rotate({key: p(self.imt[0])}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) + expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(rotated[key], expected) + assert_allclose(rotated[key], p(expected)) def test_k(self): key = None rotate = Rotate90d(keys=key, k=2) for p in TEST_NDARRAYS: rotated = rotate({key: p(self.imt[0])}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) + expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(rotated[key], expected) + assert_allclose(rotated[key], p(expected)) def test_spatial_axes(self): key = "test" rotate = Rotate90d(keys=key, spatial_axes=(0, 1)) for p in TEST_NDARRAYS: rotated = rotate({key: p(self.imt[0])}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) + expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(rotated[key], expected) + assert_allclose(rotated[key], p(expected)) def test_prob_k_spatial_axes(self): key = "test" rotate = Rotate90d(keys=key, k=2, spatial_axes=(0, 1)) for p in TEST_NDARRAYS: rotated = rotate({key: p(self.imt[0])}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) + expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(rotated[key], expected) + assert_allclose(rotated[key], p(expected)) def test_no_key(self): key = "unknown" diff --git a/tests/test_rotated.py b/tests/test_rotated.py index 2ea421101b..cd27dd5406 100644 --- a/tests/test_rotated.py +++ b/tests/test_rotated.py @@ -10,36 +10,38 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import Rotated -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D -TEST_CASES_2D = [ - (-np.pi / 6, False, "bilinear", "border", False), - (-np.pi / 4, True, "bilinear", "border", False), - (np.pi / 4.5, True, "nearest", "reflection", False), - (-np.pi, False, "nearest", "zeros", False), - (np.pi / 2, False, "bilinear", "zeros", True), -] +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, -np.pi / 6, False, "bilinear", "border", False)) + TEST_CASES_2D.append((p, -np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_2D.append((p, -np.pi, False, "nearest", "zeros", False)) + TEST_CASES_2D.append((p, np.pi / 2, False, "bilinear", "zeros", True)) -TEST_CASES_3D = [ - (-np.pi / 6, False, "bilinear", "border", False), - (-np.pi / 4, True, "bilinear", "border", False), - (np.pi / 4.5, True, "nearest", "reflection", False), - (-np.pi, False, "nearest", "zeros", False), - (np.pi / 2, False, "bilinear", "zeros", True), -] +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append((p, -np.pi / 6, False, "bilinear", "border", False)) + TEST_CASES_3D.append((p, -np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_3D.append((p, np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_3D.append((p, -np.pi, False, "nearest", "zeros", False)) + TEST_CASES_3D.append((p, np.pi / 2, False, "bilinear", "zeros", True)) class TestRotated2D(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES_2D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotated(("img", "seg"), angle, keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -52,6 +54,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") @@ -64,9 +68,9 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne class TestRotated3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotated(("img", "seg"), [0, angle, 0], keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -79,6 +83,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], np.rad2deg(angle), (0, 2), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected.astype(np.float32), rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 voxels.") @@ -91,9 +97,9 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne class TestRotated3DXY(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotated(("img", "seg"), [0, 0, angle], keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -106,6 +112,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 voxels") diff --git a/tests/test_scale_intensity.py b/tests/test_scale_intensity.py index c2485af616..24c6900ba5 100644 --- a/tests/test_scale_intensity.py +++ b/tests/test_scale_intensity.py @@ -26,14 +26,14 @@ def test_range_scale(self): maxa = self.imt.max() norm = (self.imt - mina) / (maxa - mina) expected = p((norm * (2.0 - 1.0)) + 1.0) - assert_allclose(result, expected, rtol=1e-7, atol=0) + assert_allclose(result, expected, type_test=False, rtol=1e-7, atol=0) def test_factor_scale(self): for p in TEST_NDARRAYS: scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1) result = scaler(p(self.imt)) expected = p((self.imt * (1 + 0.1)).astype(np.float32)) - assert_allclose(result, expected, rtol=1e-7, atol=0) + assert_allclose(result, p(expected), rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_scale_intensity_range.py b/tests/test_scale_intensity_range.py index cba07d9157..d06bfd3596 100644 --- a/tests/test_scale_intensity_range.py +++ b/tests/test_scale_intensity_range.py @@ -11,19 +11,18 @@ import unittest -import numpy as np - from monai.transforms import ScaleIntensityRange -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class IntensityScaleIntensityRange(NumpyImageTestCase2D): def test_image_scale_intensity_range(self): scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=50, b_max=80) - scaled = scaler(self.imt) - expected = (self.imt - 20) / 88 - expected = expected * 30 + 50 - self.assertTrue(np.allclose(scaled, expected)) + for p in TEST_NDARRAYS: + scaled = scaler(p(self.imt)) + expected = (self.imt - 20) / 88 + expected = expected * 30 + 50 + assert_allclose(scaled, p(expected)) if __name__ == "__main__": diff --git a/tests/test_scale_intensity_range_percentiles.py b/tests/test_scale_intensity_range_percentiles.py index 015162c8de..5ba3a1e1ee 100644 --- a/tests/test_scale_intensity_range_percentiles.py +++ b/tests/test_scale_intensity_range_percentiles.py @@ -14,7 +14,7 @@ import numpy as np from monai.transforms.intensity.array import ScaleIntensityRangePercentiles -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestScaleIntensityRangePercentiles(NumpyImageTestCase2D): @@ -30,7 +30,9 @@ def test_scaling(self): expected = (img - a_min) / (a_max - a_min) expected = (expected * (b_max - b_min)) + b_min scaler = ScaleIntensityRangePercentiles(lower=lower, upper=upper, b_min=b_min, b_max=b_max) - self.assertTrue(np.allclose(expected, scaler(img))) + for p in TEST_NDARRAYS: + result = scaler(p(img)) + assert_allclose(result, p(expected)) def test_relative_scaling(self): img = self.imt @@ -47,7 +49,9 @@ def test_relative_scaling(self): expected_img = (img - expected_a_min) / (expected_a_max - expected_a_min) expected_img = (expected_img * (expected_b_max - expected_b_min)) + expected_b_min - self.assertTrue(np.allclose(expected_img, scaler(img))) + for p in TEST_NDARRAYS: + result = scaler(p(img)) + assert_allclose(result, p(expected_img)) def test_invalid_instantiation(self): self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=-10, upper=99, b_min=0, b_max=255) diff --git a/tests/test_scale_intensity_ranged.py b/tests/test_scale_intensity_ranged.py index a8cac414e8..dc064a7708 100644 --- a/tests/test_scale_intensity_ranged.py +++ b/tests/test_scale_intensity_ranged.py @@ -11,20 +11,19 @@ import unittest -import numpy as np - from monai.transforms import ScaleIntensityRanged -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class IntensityScaleIntensityRanged(NumpyImageTestCase2D): def test_image_scale_intensity_ranged(self): key = "img" scaler = ScaleIntensityRanged(keys=key, a_min=20, a_max=108, b_min=50, b_max=80) - scaled = scaler({key: self.imt}) - expected = (self.imt - 20) / 88 - expected = expected * 30 + 50 - self.assertTrue(np.allclose(scaled[key], expected)) + for p in TEST_NDARRAYS: + scaled = scaler({key: p(self.imt)}) + expected = (self.imt - 20) / 88 + expected = expected * 30 + 50 + assert_allclose(scaled[key], p(expected)) if __name__ == "__main__": diff --git a/tests/test_scale_intensityd.py b/tests/test_scale_intensityd.py index 6e13dbc272..ce298f20af 100644 --- a/tests/test_scale_intensityd.py +++ b/tests/test_scale_intensityd.py @@ -19,23 +19,23 @@ class TestScaleIntensityd(NumpyImageTestCase2D): def test_range_scale(self): + key = "img" for p in TEST_NDARRAYS: - key = "img" scaler = ScaleIntensityd(keys=[key], minv=1.0, maxv=2.0) result = scaler({key: p(self.imt)}) mina = np.min(self.imt) maxa = np.max(self.imt) norm = (self.imt - mina) / (maxa - mina) expected = (norm * (2.0 - 1.0)) + 1.0 - assert_allclose(result[key], expected) + assert_allclose(result[key], p(expected)) def test_factor_scale(self): + key = "img" for p in TEST_NDARRAYS: - key = "img" scaler = ScaleIntensityd(keys=[key], minv=None, maxv=None, factor=0.1) result = scaler({key: p(self.imt)}) expected = (self.imt * (1 + 0.1)).astype(np.float32) - assert_allclose(result[key], expected) + assert_allclose(result[key], p(expected)) if __name__ == "__main__": diff --git a/tests/test_shift_intensityd.py b/tests/test_shift_intensityd.py index 0396857781..66aad23b1e 100644 --- a/tests/test_shift_intensityd.py +++ b/tests/test_shift_intensityd.py @@ -24,7 +24,7 @@ def test_value(self): shifter = ShiftIntensityd(keys=[key], offset=1.0) result = shifter({key: p(self.imt)}) expected = self.imt + 1.0 - assert_allclose(result[key], expected) + assert_allclose(result[key], p(expected)) def test_factor(self): key = "img" diff --git a/tests/test_threshold_intensity.py b/tests/test_threshold_intensity.py index a6d3895709..075a650ec0 100644 --- a/tests/test_threshold_intensity.py +++ b/tests/test_threshold_intensity.py @@ -15,20 +15,21 @@ from parameterized import parameterized from monai.transforms import ThresholdIntensity +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [{"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)] - -TEST_CASE_2 = [{"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)] - -TEST_CASE_3 = [{"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)]) + TESTS.append([p, {"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)]) + TESTS.append([p, {"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)]) class TestThresholdIntensity(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_value(self, input_param, expected_value): - test_data = np.arange(10) + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, expected_value): + test_data = in_type(np.arange(10)) result = ThresholdIntensity(**input_param)(test_data) - np.testing.assert_allclose(result, expected_value) + assert_allclose(result, in_type(expected_value)) if __name__ == "__main__": diff --git a/tests/test_threshold_intensityd.py b/tests/test_threshold_intensityd.py index efcfcfe604..a2a9fdcf2b 100644 --- a/tests/test_threshold_intensityd.py +++ b/tests/test_threshold_intensityd.py @@ -15,31 +15,41 @@ from parameterized import parameterized from monai.transforms import ThresholdIntensityd - -TEST_CASE_1 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0}, - (0, 0, 0, 0, 0, 0, 6, 7, 8, 9), -] - -TEST_CASE_2 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0}, - (0, 1, 2, 3, 4, 0, 0, 0, 0, 0), -] - -TEST_CASE_3 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5}, - (5, 5, 5, 5, 5, 5, 6, 7, 8, 9), -] +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0}, + (0, 0, 0, 0, 0, 0, 6, 7, 8, 9), + ] + ) + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0}, + (0, 1, 2, 3, 4, 0, 0, 0, 0, 0), + ] + ) + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5}, + (5, 5, 5, 5, 5, 5, 6, 7, 8, 9), + ] + ) class TestThresholdIntensityd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_value(self, input_param, expected_value): - test_data = {"image": np.arange(10), "label": np.arange(10), "extra": np.arange(10)} + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, expected_value): + test_data = {"image": in_type(np.arange(10)), "label": in_type(np.arange(10)), "extra": in_type(np.arange(10))} result = ThresholdIntensityd(**input_param)(test_data) - np.testing.assert_allclose(result["image"], expected_value) - np.testing.assert_allclose(result["label"], expected_value) - np.testing.assert_allclose(result["extra"], expected_value) + assert_allclose(result["image"], in_type(expected_value)) + assert_allclose(result["label"], in_type(expected_value)) + assert_allclose(result["extra"], in_type(expected_value)) if __name__ == "__main__": diff --git a/tests/test_to_cupy.py b/tests/test_to_cupy.py index 8b00e12539..0fd9607339 100644 --- a/tests/test_to_cupy.py +++ b/tests/test_to_cupy.py @@ -22,49 +22,81 @@ cp, has_cp = optional_import("cupy") +@skipUnless(has_cp, "CuPy is required.") class TestToCupy(unittest.TestCase): - @skipUnless(has_cp, "CuPy is required.") def test_cupy_input(self): - test_data = cp.array([[1, 2], [3, 4]]) + test_data = cp.array([[1, 2], [3, 4]], dtype=cp.float32) test_data = cp.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) result = ToCupy()(test_data) + self.assertTrue(result.dtype == cp.float32) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + cp.testing.assert_allclose(result, test_data) + + def test_cupy_input_dtype(self): + test_data = cp.array([[1, 2], [3, 4]], dtype=cp.float32) + test_data = cp.rot90(test_data) + self.assertFalse(test_data.flags["C_CONTIGUOUS"]) + result = ToCupy(cp.uint8)(test_data) + self.assertTrue(result.dtype == cp.uint8) self.assertTrue(isinstance(result, cp.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) - @skipUnless(has_cp, "CuPy is required.") def test_numpy_input(self): - test_data = np.array([[1, 2], [3, 4]]) + test_data = np.array([[1, 2], [3, 4]], dtype=np.float32) test_data = np.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) result = ToCupy()(test_data) + self.assertTrue(result.dtype == cp.float32) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + cp.testing.assert_allclose(result, test_data) + + def test_numpy_input_dtype(self): + test_data = np.array([[1, 2], [3, 4]], dtype=np.float32) + test_data = np.rot90(test_data) + self.assertFalse(test_data.flags["C_CONTIGUOUS"]) + result = ToCupy(np.uint8)(test_data) + self.assertTrue(result.dtype == cp.uint8) self.assertTrue(isinstance(result, cp.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) - @skipUnless(has_cp, "CuPy is required.") def test_tensor_input(self): - test_data = torch.tensor([[1, 2], [3, 4]]) + test_data = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) test_data = test_data.rot90() self.assertFalse(test_data.is_contiguous()) result = ToCupy()(test_data) + self.assertTrue(result.dtype == cp.float32) self.assertTrue(isinstance(result, cp.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - cp.testing.assert_allclose(result, test_data.numpy()) + cp.testing.assert_allclose(result, test_data) - @skipUnless(has_cp, "CuPy is required.") @skip_if_no_cuda def test_tensor_cuda_input(self): - test_data = torch.tensor([[1, 2], [3, 4]]).cuda() + test_data = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32).cuda() test_data = test_data.rot90() self.assertFalse(test_data.is_contiguous()) result = ToCupy()(test_data) + self.assertTrue(result.dtype == cp.float32) self.assertTrue(isinstance(result, cp.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - cp.testing.assert_allclose(result, test_data.cpu().numpy()) + cp.testing.assert_allclose(result, test_data) + + @skip_if_no_cuda + def test_tensor_cuda_input_dtype(self): + test_data = torch.tensor([[1, 2], [3, 4]], dtype=torch.uint8).cuda() + test_data = test_data.rot90() + self.assertFalse(test_data.is_contiguous()) + + result = ToCupy(dtype="float32")(test_data) + self.assertTrue(result.dtype == cp.float32) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + cp.testing.assert_allclose(result, test_data) - @skipUnless(has_cp, "CuPy is required.") def test_list_tuple(self): test_data = [[1, 2], [3, 4]] result = ToCupy()(test_data) diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py index b48727c01d..c7631540b8 100644 --- a/tests/test_to_numpy.py +++ b/tests/test_to_numpy.py @@ -31,16 +31,17 @@ def test_cupy_input(self): result = ToNumpy()(test_data) self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data.get()) + assert_allclose(result, test_data.get(), type_test=False) def test_numpy_input(self): test_data = np.array([[1, 2], [3, 4]]) test_data = np.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) - result = ToNumpy()(test_data) + result = ToNumpy(dtype="float32")(test_data) self.assertTrue(isinstance(result, np.ndarray)) + self.assertTrue(result.dtype == np.float32) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) def test_tensor_input(self): test_data = torch.tensor([[1, 2], [3, 4]]) @@ -49,7 +50,7 @@ def test_tensor_input(self): result = ToNumpy()(test_data) self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) @skip_if_no_cuda def test_tensor_cuda_input(self): @@ -59,21 +60,21 @@ def test_tensor_cuda_input(self): result = ToNumpy()(test_data) self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) def test_list_tuple(self): test_data = [[1, 2], [3, 4]] result = ToNumpy()(test_data) - assert_allclose(result, np.asarray(test_data)) + assert_allclose(result, np.asarray(test_data), type_test=False) test_data = ((1, 2), (3, 4)) result = ToNumpy()(test_data) - assert_allclose(result, np.asarray(test_data)) + assert_allclose(result, np.asarray(test_data), type_test=False) def test_single_value(self): for test_data in [5, np.array(5), torch.tensor(5)]: result = ToNumpy()(test_data) self.assertTrue(isinstance(result, np.ndarray)) - assert_allclose(result, np.asarray(test_data)) + assert_allclose(result, np.asarray(test_data), type_test=False) self.assertEqual(result.ndim, 0) diff --git a/tests/test_to_numpyd.py b/tests/test_to_numpyd.py index 5acaef39c7..0b0b032ef2 100644 --- a/tests/test_to_numpyd.py +++ b/tests/test_to_numpyd.py @@ -31,7 +31,7 @@ def test_cupy_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data.get()) + assert_allclose(result, test_data.get(), type_test=False) def test_numpy_input(self): test_data = np.array([[1, 2], [3, 4]]) @@ -40,7 +40,7 @@ def test_numpy_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) def test_tensor_input(self): test_data = torch.tensor([[1, 2], [3, 4]]) @@ -49,7 +49,7 @@ def test_tensor_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) @skip_if_no_cuda def test_tensor_cuda_input(self): @@ -59,7 +59,7 @@ def test_tensor_cuda_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) if __name__ == "__main__": diff --git a/tests/test_to_pil.py b/tests/test_to_pil.py index 5690645dd8..b4581053c0 100644 --- a/tests/test_to_pil.py +++ b/tests/test_to_pil.py @@ -43,7 +43,7 @@ class TestToPIL(unittest.TestCase): def test_value(self, test_data): result = ToPIL()(test_data) self.assertTrue(isinstance(result, PILImageImage)) - assert_allclose(np.array(result), test_data) + assert_allclose(np.array(result), test_data, type_test=False) if __name__ == "__main__": diff --git a/tests/test_to_pild.py b/tests/test_to_pild.py index 3a15b1e507..3b83fa5258 100644 --- a/tests/test_to_pild.py +++ b/tests/test_to_pild.py @@ -30,9 +30,7 @@ PILImageImage, _ = optional_import("PIL.Image", name="Image") im = [[1.0, 2.0], [3.0, 4.0]] -TESTS = [] -for p in TEST_NDARRAYS: - TESTS.append([{"keys": "image"}, {"image": p(im)}]) +TESTS = [[{"keys": "image"}, {"image": p(im)}] for p in TEST_NDARRAYS] if has_pil: TESTS.append([{"keys": "image"}, {"image": pil_image_fromarray(np.array(im))}]) @@ -43,7 +41,7 @@ class TestToPIL(unittest.TestCase): def test_values(self, input_param, test_data): result = ToPILd(**input_param)(test_data)[input_param["keys"]] self.assertTrue(isinstance(result, PILImageImage)) - assert_allclose(np.array(result), test_data[input_param["keys"]]) + assert_allclose(np.array(result), test_data[input_param["keys"]], type_test=False) if __name__ == "__main__": diff --git a/tests/test_transpose.py b/tests/test_transpose.py index 10882c9dd8..16cca49e1c 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -42,7 +42,7 @@ def test_transpose(self, im, indices): if isinstance(im, torch.Tensor): im = im.cpu().numpy() out2 = np.transpose(im, indices) - assert_allclose(out1, out2) + assert_allclose(out1, out2, type_test=False) if __name__ == "__main__": diff --git a/tests/test_transposed.py b/tests/test_transposed.py index 88ecd0c872..2f9558b74e 100644 --- a/tests/test_transposed.py +++ b/tests/test_transposed.py @@ -57,13 +57,13 @@ def test_transpose(self, im, indices): if isinstance(im, torch.Tensor): im = im.cpu().numpy() out_gt = np.transpose(im, indices) - assert_allclose(out_im1, out_gt) - assert_allclose(out_im2, out_gt) + assert_allclose(out_im1, out_gt, type_test=False) + assert_allclose(out_im2, out_gt, type_test=False) # test inverse fwd_inv_data = tr.inverse(out_data) for i, j in zip(data.values(), fwd_inv_data.values()): - assert_allclose(i, j) + assert_allclose(i, j, type_test=False) if __name__ == "__main__": diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py new file mode 100644 index 0000000000..c8e0a35c92 --- /dev/null +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -0,0 +1,46 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from monai.transforms.utils_pytorch_numpy_unification import percentile +from tests.utils import TEST_NDARRAYS, assert_allclose, set_determinism + + +class TestPytorchNumpyUnification(unittest.TestCase): + def setUp(self) -> None: + set_determinism(0) + + def test_percentile(self): + for size in (1, 100): + q = np.random.randint(0, 100, size=size) + results = [] + for p in TEST_NDARRAYS: + arr = p(np.arange(100 * 101).reshape(1, 100, 101).astype(np.float32)) + results.append(percentile(arr, q)) + # pre torch 1.7, no `quantile`. Our own method doesn't interpolate, + # so we can only be accurate to 0.5 + atol = 0.5 if not hasattr(torch, "quantile") else 1e-4 + assert_allclose(results[0], results[-1], type_test=False, atol=atol) + + def test_fails(self): + for p in TEST_NDARRAYS: + for q in (-1, 101): + arr = p(np.arange(100 * 101).reshape(1, 100, 101).astype(np.float32)) + with self.assertRaises(ValueError): + percentile(arr, q) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_zoom.py b/tests/test_zoom.py index a99e110052..9411988a7e 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -37,7 +37,7 @@ def test_correct_results(self, zoom, mode): for channel in self.imt[0]: expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) expected = np.stack(expected).astype(np.float32) - assert_allclose(zoomed, expected, atol=1.0) + assert_allclose(zoomed, p(expected), atol=1.0) def test_keep_size(self): for p in TEST_NDARRAYS: diff --git a/tests/test_zoomd.py b/tests/test_zoomd.py index 1ebd7d2d08..6231978ca7 100644 --- a/tests/test_zoomd.py +++ b/tests/test_zoomd.py @@ -27,22 +27,18 @@ class TestZoomd(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, zoom, mode, keep_size): key = "img" - zoom_fn = Zoomd( - key, - zoom=zoom, - mode=mode, - keep_size=keep_size, - ) + zoom_fn = Zoomd(key, zoom=zoom, mode=mode, keep_size=keep_size) for p in TEST_NDARRAYS: zoomed = zoom_fn({key: p(self.imt[0])}) _order = 0 if mode.endswith("linear"): _order = 1 - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) + expected = [ + zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False) for channel in self.imt[0] + ] + expected = np.stack(expected).astype(np.float32) - assert_allclose(expected, zoomed[key], atol=1.0) + assert_allclose(zoomed[key], p(expected), atol=1.0) def test_keep_size(self): key = "img" diff --git a/tests/utils.py b/tests/utils.py index 1375cd2d72..6b7f6c4c16 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -57,24 +57,45 @@ def clone(data: NdarrayTensor) -> NdarrayTensor: return copy.deepcopy(data) -def assert_allclose(a: NdarrayOrTensor, b: NdarrayOrTensor, *args, **kwargs): +def assert_allclose( + actual: NdarrayOrTensor, + desired: NdarrayOrTensor, + type_test: bool = True, + device_test: bool = False, + *args, + **kwargs, +): """ - Assert that all values of two data objects are close. + Assert that types and all values of two data objects are close. Args: - a (NdarrayOrTensor): Pytorch Tensor or numpy array for comparison - b (NdarrayOrTensor): Pytorch Tensor or numpy array to compare against + actual: Pytorch Tensor or numpy array for comparison. + desired: Pytorch Tensor or numpy array to compare against. + type_test: whether to test that `actual` and `desired` are both numpy arrays or torch tensors. + device_test: whether to test the device property. + args: extra arguments to pass on to `np.testing.assert_allclose`. + kwargs: extra arguments to pass on to `np.testing.assert_allclose`. + + """ - a = a.cpu() if isinstance(a, torch.Tensor) else a - b = b.cpu() if isinstance(b, torch.Tensor) else b - np.testing.assert_allclose(a, b, *args, **kwargs) + if type_test: + # check both actual and desired are of the same type + np.testing.assert_equal(isinstance(actual, np.ndarray), isinstance(desired, np.ndarray), "numpy type") + np.testing.assert_equal(isinstance(actual, torch.Tensor), isinstance(desired, torch.Tensor), "torch type") + + if isinstance(desired, torch.Tensor) or isinstance(actual, torch.Tensor): + if device_test: + np.testing.assert_equal(str(actual.device), str(desired.device), "torch device check") # type: ignore + actual = actual.cpu().numpy() if isinstance(actual, torch.Tensor) else actual + desired = desired.cpu().numpy() if isinstance(desired, torch.Tensor) else desired + np.testing.assert_allclose(actual, desired, *args, **kwargs) def test_pretrained_networks(network, input_param, device): try: net = network(**input_param).to(device) except (URLError, HTTPError, ContentTooShortError) as e: - raise unittest.SkipTest(e) + raise unittest.SkipTest(e) from e return net diff --git a/tests/vltransformer.py b/tests/vltransformer.py deleted file mode 100644 index af095a181c..0000000000 --- a/tests/vltransformer.py +++ /dev/null @@ -1,355 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import os -import shutil -import tarfile -import tempfile -from typing import Sequence, Union - -import torch -from torch import nn - -from monai.utils import optional_import - -transformers = optional_import("transformers") -load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert") -cached_path = optional_import("transformers.file_utils", name="cached_path")[0] -BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0] -BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0] - - -class BertPreTrainedModel(nn.Module): - """Module to load BERT pre-trained weights. - Based on: - LXMERT - https://github.com/airsplay/lxmert - BERT (pytorch-transformer) - https://github.com/huggingface/transformers - """ - - def __init__(self, *inputs, **kwargs) -> None: - super(BertPreTrainedModel, self).__init__() - - def init_bert_weights(self, module): - if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, torch.nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - @classmethod - def from_pretrained( - cls, - num_language_layers, - num_vision_layers, - num_mixed_layers, - bert_config, - state_dict=None, - cache_dir=None, - from_tf=False, - *inputs, - **kwargs, - ): - archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz" - resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) - tempdir = None - if os.path.isdir(resolved_archive_file) or from_tf: - serialization_dir = resolved_archive_file - else: - tempdir = tempfile.mkdtemp() - with tarfile.open(resolved_archive_file, "r:gz") as archive: - archive.extractall(tempdir) - serialization_dir = tempdir - model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs) - if state_dict is None and not from_tf: - weights_path = os.path.join(serialization_dir, "pytorch_model.bin") - state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None) - if tempdir: - shutil.rmtree(tempdir) - if from_tf: - weights_path = os.path.join(serialization_dir, "model.ckpt") - return load_tf_weights_in_bert(model, weights_path) - old_keys = [] - new_keys = [] - for key in state_dict.keys(): - new_key = None - if "gamma" in key: - new_key = key.replace("gamma", "weight") - if "beta" in key: - new_key = key.replace("beta", "bias") - if new_key: - old_keys.append(key) - new_keys.append(new_key) - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) - missing_keys = [] - unexpected_keys = [] - error_msgs = [] - metadata = getattr(state_dict, "_metadata", None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - def load(module, prefix=""): - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - module._load_from_state_dict( - state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs - ) - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + ".") - - start_prefix = "" - if not hasattr(model, "bert") and any(s.startswith("bert.") for s in state_dict.keys()): - start_prefix = "bert." - load(model, prefix=start_prefix) - return model - - -class BertAttention(nn.Module): - """BERT attention layer. - Based on: BERT (pytorch-transformer) - https://github.com/huggingface/transformers - """ - - def __init__( - self, - config, - ) -> None: - super().__init__() - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query = nn.Linear(config.hidden_size, self.all_head_size) - self.key = nn.Linear(config.hidden_size, self.all_head_size) - self.value = nn.Linear(config.hidden_size, self.all_head_size) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, hidden_states, context): - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(context) - mixed_value_layer = self.value(context) - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - attention_probs = self.dropout(nn.Softmax(dim=-1)(attention_scores)) - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - return context_layer - - -class BertOutput(nn.Module): - """BERT output layer. - Based on: BERT (pytorch-transformer) - https://github.com/huggingface/transformers - """ - - def __init__(self, config) -> None: - super(BertOutput, self).__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BertMixedLayer(nn.Module): - """BERT cross attention layer. - Based on: BERT (pytorch-transformer) - https://github.com/huggingface/transformers - """ - - def __init__( - self, - config, - ) -> None: - super().__init__() - self.att = BertAttention(config) - self.output = BertOutput(config) - - def forward(self, x, y): - output = self.att(x, y) - return self.output(output, x) - - -class Pooler(nn.Module): - """BERT pooler layer. - Based on: BERT (pytorch-transformer) - https://github.com/huggingface/transformers - """ - - def __init__( - self, - hidden_size, - ) -> None: - super(Pooler, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class MultiModal(BertPreTrainedModel): - """ - Multimodal Transformers From Pretrained BERT Weights" - """ - - def __init__( - self, - num_language_layers: int, - num_vision_layers: int, - num_mixed_layers: int, - bert_config: dict, # type: ignore - ) -> None: - """ - Args: - num_language_layers: number of language transformer layers. - num_vision_layers: number of vision transformer layers. - bert_config: configuration for bert language transformer encoder. - - """ - super().__init__() - self.config = type("obj", (object,), bert_config) - self.embeddings = BertEmbeddings(self.config) - self.language_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_language_layers)]) - self.vision_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_vision_layers)]) - self.mixed_encoder = nn.ModuleList([BertMixedLayer(self.config) for _ in range(num_mixed_layers)]) - self.apply(self.init_bert_weights) - - def forward(self, input_ids, token_type_ids=None, vision_feats=None, attention_mask=None): - language_features = self.embeddings(input_ids, token_type_ids) - for layer in self.vision_encoder: - hidden_state_vision = layer(vision_feats, None)[0] - for layer in self.language_encoder: - hidden_state_language = layer(language_features, attention_mask)[0] - for layer in self.mixed_encoder: - hidden_state_mixed = layer(hidden_state_language, hidden_state_vision) - return hidden_state_mixed - - -class VLTransformers(torch.nn.Module): - """ - Vision Language Multimodal Transformers" - """ - - def __init__( - self, - in_channels: int, - img_size: Union[Sequence[int], int], # type: ignore - patch_size: Union[Sequence[int], int], # type: ignore - num_classes: int, - num_language_layers: int, - num_vision_layers: int, - num_mixed_layers: int, - drop_out: float = 0.0, - bert_config: dict = { - "attention_probs_dropout_prob": 0.1, - "classifier_dropout": None, - "gradient_checkpointing": False, - "hidden_act": "gelu", - "hidden_dropout_prob": 0.1, - "hidden_size": 768, - "initializer_range": 0.02, - "intermediate_size": 3072, - "layer_norm_eps": 1e-12, - "max_position_embeddings": 512, - "model_type": "bert", - "num_attention_heads": 12, - "num_hidden_layers": 12, - "pad_token_id": 0, - "position_embedding_type": "absolute", - "transformers_version": "4.10.2", - "type_vocab_size": 2, - "use_cache": True, - "vocab_size": 30522, - "chunk_size_feed_forward": 0, - "is_decoder": False, - "add_cross_attention": False, - }, - ) -> None: - """ - Args: - in_channels: dimension of input channels. - img_size: dimension of input image. - patch_size: dimension of patch size. - num_classes: number of classes if classification is used. - num_language_layers: number of language transformer layers. - num_vision_layers: number of vision transformer layers. - num_mixed_layers: number of mixed transformer layers. - drop_out: faction of the input units to drop. - bert_config: configuration for bert language transformer encoder. - Examples:: - # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers, - 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head - >>> net = VLTransformers(in_channels=3, img_size=(224, 224), num_classes=3, num_language_layers=2, - num_vision_layers=2, num_mixed_layers=2, drop_out=0.2) - """ - super(VLTransformers, self).__init__() - - if not (0 <= drop_out <= 1): - raise ValueError("dropout_rate should be in the range of 0 and 1.") - - if (img_size[0] % patch_size[0] != 0) or (img_size[1] % patch_size[1] != 0): # type: ignore - raise ValueError("img_size should be divisible by patch_size.") - - self.multimodal = MultiModal.from_pretrained( - num_language_layers=num_language_layers, - num_vision_layers=num_vision_layers, - num_mixed_layers=num_mixed_layers, - bert_config=bert_config, - ) - - self.embed_dim = 768 - self.patch_size = patch_size - self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) # type: ignore - self.vision_proj = nn.Conv2d( - in_channels=in_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size - ) - self.norm_vision_pos = nn.LayerNorm(self.embed_dim) - self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, self.embed_dim)) - self.pooler = Pooler(hidden_size=self.embed_dim) - self.drop = torch.nn.Dropout(drop_out) - self.cls_head = torch.nn.Linear(self.embed_dim, num_classes) - - def forward(self, input_ids, token_type_ids=None, vision_feats=None): - attention_mask = torch.ones_like(input_ids).unsqueeze(1).unsqueeze(2) - attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) - attention_mask = (1.0 - attention_mask) * -10000.0 - vision_feats = self.vision_proj(vision_feats).flatten(2).transpose(1, 2) - vision_feats = self.norm_vision_pos(vision_feats) - vision_feats = vision_feats + self.pos_embed_vis - hidden_state_mixed = self.multimodal( - input_ids=input_ids, token_type_ids=token_type_ids, vision_feats=vision_feats, attention_mask=attention_mask - ) - pooled_features = self.pooler(hidden_state_mixed) - logits = self.cls_head(self.drop(pooled_features)) - return logits From 0212c61b353cb82805d3a3e1517d2c32ae398add Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Wed, 15 Sep 2021 20:54:59 -0700 Subject: [PATCH 05/11] add multimodal transformers Signed-off-by: ahatamizadeh --- monai/networks/nets/vltransformer.py | 83 +++++++++++++++++----------- 1 file changed, 51 insertions(+), 32 deletions(-) diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py index af095a181c..97661f7618 100644 --- a/monai/networks/nets/vltransformer.py +++ b/monai/networks/nets/vltransformer.py @@ -270,31 +270,28 @@ def __init__( num_language_layers: int, num_vision_layers: int, num_mixed_layers: int, + hidden_size: int = 768, drop_out: float = 0.0, - bert_config: dict = { - "attention_probs_dropout_prob": 0.1, - "classifier_dropout": None, - "gradient_checkpointing": False, - "hidden_act": "gelu", - "hidden_dropout_prob": 0.1, - "hidden_size": 768, - "initializer_range": 0.02, - "intermediate_size": 3072, - "layer_norm_eps": 1e-12, - "max_position_embeddings": 512, - "model_type": "bert", - "num_attention_heads": 12, - "num_hidden_layers": 12, - "pad_token_id": 0, - "position_embedding_type": "absolute", - "transformers_version": "4.10.2", - "type_vocab_size": 2, - "use_cache": True, - "vocab_size": 30522, - "chunk_size_feed_forward": 0, - "is_decoder": False, - "add_cross_attention": False, - }, + attention_probs_dropout_prob: float = 0.1, + gradient_checkpointing: bool = False, + hidden_act: str = "gelu", + hidden_dropout_prob: float = 0.1, + initializer_range: float = 0.02, + intermediate_size: int = 3072, + layer_norm_eps: float = 1e-12, + max_position_embeddings: int = 512, + model_type: str = "bert", + num_attention_heads: int = 12, + num_hidden_layers: int = 12, + pad_token_id: int = 0, + position_embedding_type: str = "absolute", + transformers_version: str = "4.10.2", + type_vocab_size: int = 2, + use_cache: bool = True, + vocab_size: int = 30522, + chunk_size_feed_forward: int = 0, + is_decoder: bool = False, + add_cross_attention: bool = False, ) -> None: """ Args: @@ -314,9 +311,32 @@ def __init__( num_vision_layers=2, num_mixed_layers=2, drop_out=0.2) """ super(VLTransformers, self).__init__() - + bert_config = { + "attention_probs_dropout_prob": attention_probs_dropout_prob, + "classifier_dropout": None, + "gradient_checkpointing": gradient_checkpointing, + "hidden_act": hidden_act, + "hidden_dropout_prob": hidden_dropout_prob, + "hidden_size": hidden_size, + "initializer_range": initializer_range, + "intermediate_size": intermediate_size, + "layer_norm_eps": layer_norm_eps, + "max_position_embeddings": max_position_embeddings, + "model_type": model_type, + "num_attention_heads": num_attention_heads, + "num_hidden_layers": num_hidden_layers, + "pad_token_id": pad_token_id, + "position_embedding_type": position_embedding_type, + "transformers_version": transformers_version, + "type_vocab_size": type_vocab_size, + "use_cache": use_cache, + "vocab_size": vocab_size, + "chunk_size_feed_forward": chunk_size_feed_forward, + "is_decoder": is_decoder, + "add_cross_attention": add_cross_attention, + } if not (0 <= drop_out <= 1): - raise ValueError("dropout_rate should be in the range of 0 and 1.") + raise ValueError("dropout_rate should be between 0 and 1.") if (img_size[0] % patch_size[0] != 0) or (img_size[1] % patch_size[1] != 0): # type: ignore raise ValueError("img_size should be divisible by patch_size.") @@ -328,17 +348,16 @@ def __init__( bert_config=bert_config, ) - self.embed_dim = 768 self.patch_size = patch_size self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) # type: ignore self.vision_proj = nn.Conv2d( - in_channels=in_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + in_channels=in_channels, out_channels=hidden_size, kernel_size=self.patch_size, stride=self.patch_size ) - self.norm_vision_pos = nn.LayerNorm(self.embed_dim) - self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, self.embed_dim)) - self.pooler = Pooler(hidden_size=self.embed_dim) + self.norm_vision_pos = nn.LayerNorm(hidden_size) + self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, hidden_size)) + self.pooler = Pooler(hidden_size=hidden_size) self.drop = torch.nn.Dropout(drop_out) - self.cls_head = torch.nn.Linear(self.embed_dim, num_classes) + self.cls_head = torch.nn.Linear(hidden_size, num_classes) def forward(self, input_ids, token_type_ids=None, vision_feats=None): attention_mask = torch.ones_like(input_ids).unsqueeze(1).unsqueeze(2) From 0ed73ed709b238574860d8a8561fb9d4f198a464 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Wed, 15 Sep 2021 21:05:11 -0700 Subject: [PATCH 06/11] add multimodal transformers Signed-off-by: ahatamizadeh --- docs/requirements.txt | 1 + monai/transforms/spatial/array.py | 2 +- requirements-dev.txt | 1 + setup.cfg | 3 +++ 4 files changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 00dd4d2c1e..47176c58a2 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -20,3 +20,4 @@ sphinxcontrib-serializinghtml sphinx-autodoc-typehints==1.11.1 pandas einops +transformers diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index c49f4e6479..cb80c2036b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -526,7 +526,7 @@ def __call__( align_corners=self.align_corners if align_corners is None else align_corners, reverse_indexing=True, ) - output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=output_shape).squeeze(0) + output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=output_shape).float().squeeze(0) self._rotation_matrix = transform out: NdarrayOrTensor out, *_ = convert_to_dst_type(output, dst=img, dtype=output.dtype) diff --git a/requirements-dev.txt b/requirements-dev.txt index 785454ad5d..ed8739ded8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -36,3 +36,4 @@ openslide-python==1.1.2 pandas requests einops +transformers diff --git a/setup.cfg b/setup.cfg index 6efe768a6f..f7ed90a14a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,7 @@ all = openslide-python==1.1.2 pandas einops + transformers nibabel = nibabel skimage = @@ -74,6 +75,8 @@ pandas = pandas einops = einops +transformers = + transformers [flake8] select = B,C,E,F,N,P,T4,W,B9 max_line_length = 120 From 344d6216427df9b54de4d44de29b4865d00054c0 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Wed, 15 Sep 2021 21:16:57 -0700 Subject: [PATCH 07/11] add multimodal transformers Signed-off-by: ahatamizadeh --- docs/source/installation.md | 4 ++-- monai/config/deviceconfig.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/installation.md b/docs/source/installation.md index 08ab109142..4bc4aa700a 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -174,9 +174,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers] ``` which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas` and `einops`, respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops` and `transformers`, respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index 273431fc72..ff45b29531 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -73,6 +73,7 @@ def get_optional_config_values(): output["psutil"] = psutil_version output["pandas"] = get_package_version("pandas") output["einops"] = get_package_version("einops") + output["transformers"] = get_package_version("transformers") return output From a4d4da4d4eec2031f197a3f127c9d81e8fb2d220 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Wed, 15 Sep 2021 21:26:24 -0700 Subject: [PATCH 08/11] add multimodal transformers Signed-off-by: ahatamizadeh --- tests/min_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index 5b376d7b57..bac6521889 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -140,6 +140,7 @@ def run_testsuit(): "test_zoom", "test_zoom_affine", "test_zoomd", + "test_vltransformer", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From fab0b8a4d0be0d3aa46ee8cde9b13fce0499c1b6 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Wed, 15 Sep 2021 21:46:09 -0700 Subject: [PATCH 09/11] add multimodal transformers Signed-off-by: ahatamizadeh --- monai/networks/nets/vltransformer.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py index 97661f7618..95f0575f74 100644 --- a/monai/networks/nets/vltransformer.py +++ b/monai/networks/nets/vltransformer.py @@ -14,7 +14,7 @@ import shutil import tarfile import tempfile -from typing import Sequence, Union +from typing import Sequence, Tuple, Union import torch from torch import nn @@ -265,7 +265,7 @@ def __init__( self, in_channels: int, img_size: Union[Sequence[int], int], # type: ignore - patch_size: Union[Sequence[int], int], # type: ignore + patch_size: Union[int, Tuple[int, int]], # type: ignore num_classes: int, num_language_layers: int, num_vision_layers: int, @@ -351,7 +351,10 @@ def __init__( self.patch_size = patch_size self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) # type: ignore self.vision_proj = nn.Conv2d( - in_channels=in_channels, out_channels=hidden_size, kernel_size=self.patch_size, stride=self.patch_size + in_channels=in_channels, + out_channels=hidden_size, + kernel_size=self.patch_size, # type: ignore + stride=self.patch_size, # type: ignore ) self.norm_vision_pos = nn.LayerNorm(hidden_size) self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, hidden_size)) From b2fef98811fafcb1415c540d068711c07e3da512 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Wed, 15 Sep 2021 22:50:47 -0700 Subject: [PATCH 10/11] add multimodal transformers Signed-off-by: ahatamizadeh --- monai/_version.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/_version.py b/monai/_version.py index fb3a60690e..79f569dd79 100644 --- a/monai/_version.py +++ b/monai/_version.py @@ -23,9 +23,9 @@ def get_keywords(): # setup.py/versioneer.py will grep for the variable names, so they must # each be defined on a line of their own. _version.py will just call # get_keywords(). - git_refnames = " (HEAD -> dev, tag: 0.7.0rc1, releasing/0.7.0)" - git_full = "0f17aa991592fc6e635e86da3061b5dd3d669597" - git_date = "2021-09-15 21:03:08 +0000" + git_refnames = "$Format:%d$" + git_full = "$Format:%H$" + git_date = "$Format:%ci$" keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} return keywords From aa62d5f2ba15d887e17cb8368b4a999983217d3b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 16 Sep 2021 07:32:21 +0100 Subject: [PATCH 11/11] adds docs Signed-off-by: Wenqi Li --- docs/source/networks.rst | 5 +++++ monai/networks/nets/__init__.py | 9 ++++++++ monai/networks/nets/autoencoder.py | 3 +++ monai/networks/nets/varautoencoder.py | 1 + monai/networks/nets/vit.py | 2 ++ monai/networks/nets/vltransformer.py | 30 ++++++++++++++++++++++----- 6 files changed, 45 insertions(+), 5 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 54c2756535..1020ff1026 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -500,6 +500,11 @@ Nets .. autoclass:: Critic :members: +`VLTransformers` +~~~~~~~~~~~~~~~~ +.. autoclass:: VLTransformers + :members: + `NetAdapter` ~~~~~~~~~~~~ .. autoclass:: NetAdapter diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index ad1ca2418b..0ddda1d6dd 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -84,4 +84,13 @@ from .unetr import UNETR from .varautoencoder import VarAutoEncoder from .vit import ViT +from .vltransformer import ( + BertAttention, + BertMixedLayer, + BertOutput, + BertPreTrainedModel, + MultiModal, + Pooler, + VLTransformers, +) from .vnet import VNet diff --git a/monai/networks/nets/autoencoder.py b/monai/networks/nets/autoencoder.py index ed5e351779..b7dc309b71 100644 --- a/monai/networks/nets/autoencoder.py +++ b/monai/networks/nets/autoencoder.py @@ -49,8 +49,11 @@ def __init__( dimensions: Optional[int] = None, ) -> None: """ + Initialize the AutoEncoder. + .. deprecated:: 0.6.0 ``dimensions`` is deprecated, use ``spatial_dims`` instead. + """ super().__init__() diff --git a/monai/networks/nets/varautoencoder.py b/monai/networks/nets/varautoencoder.py index 3baa59531a..a228efab07 100644 --- a/monai/networks/nets/varautoencoder.py +++ b/monai/networks/nets/varautoencoder.py @@ -47,6 +47,7 @@ class VarAutoEncoder(AutoEncoder): .. deprecated:: 0.6.0 ``dimensions`` is deprecated, use ``spatial_dims`` instead. + """ @deprecated_arg( diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index 3a5d94cc37..35e05727e2 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -18,6 +18,8 @@ from monai.networks.blocks.patchembedding import PatchEmbeddingBlock from monai.networks.blocks.transformerblock import TransformerBlock +__all__ = ["ViT"] + class ViT(nn.Module): """ diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py index 95f0575f74..23a1a39ded 100644 --- a/monai/networks/nets/vltransformer.py +++ b/monai/networks/nets/vltransformer.py @@ -27,6 +27,16 @@ BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0] BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0] +__all__ = [ + "BertPreTrainedModel", + "BertAttention", + "BertOutput", + "BertMixedLayer", + "Pooler", + "MultiModal", + "VLTransformers", +] + class BertPreTrainedModel(nn.Module): """Module to load BERT pre-trained weights. @@ -258,7 +268,7 @@ def forward(self, input_ids, token_type_ids=None, vision_feats=None, attention_m class VLTransformers(torch.nn.Module): """ - Vision Language Multimodal Transformers" + Vision Language Multimodal Transformers """ def __init__( @@ -304,11 +314,21 @@ def __init__( num_mixed_layers: number of mixed transformer layers. drop_out: faction of the input units to drop. bert_config: configuration for bert language transformer encoder. - Examples:: + + Examples: + + .. code-block:: python + # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers, - 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head - >>> net = VLTransformers(in_channels=3, img_size=(224, 224), num_classes=3, num_language_layers=2, - num_vision_layers=2, num_mixed_layers=2, drop_out=0.2) + # 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head + net = VLTransformers(in_channels=3, + img_size=(224, 224), + num_classes=3, + num_language_layers=2, + num_vision_layers=2, + num_mixed_layers=2, + drop_out=0.2) + """ super(VLTransformers, self).__init__() bert_config = {