From 34959589745d4f3c327bc6362d3e320e09dac134 Mon Sep 17 00:00:00 2001 From: Ali Hatamizadeh Date: Tue, 3 May 2022 16:27:38 -0700 Subject: [PATCH 01/28] add swin_unetr model (#4074) * add swin_unetr model Signed-off-by: kbressem --- monai/networks/blocks/__init__.py | 2 +- monai/networks/blocks/patchembedding.py | 109 ++- monai/networks/layers/__init__.py | 2 + monai/networks/layers/drop_path.py | 45 ++ monai/networks/layers/weight_init.py | 64 ++ monai/networks/nets/__init__.py | 1 + monai/networks/nets/swin_unetr.py | 982 ++++++++++++++++++++++++ tests/test_drop_path.py | 43 ++ tests/test_patchembedding.py | 37 +- tests/test_swin_unetr.py | 89 +++ tests/test_weight_init.py | 47 ++ 11 files changed, 1397 insertions(+), 24 deletions(-) create mode 100644 monai/networks/layers/drop_path.py create mode 100644 monai/networks/layers/weight_init.py create mode 100644 monai/networks/nets/swin_unetr.py create mode 100644 tests/test_drop_path.py create mode 100644 tests/test_swin_unetr.py create mode 100644 tests/test_weight_init.py diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index 0fdc944760..b6328734b0 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -20,7 +20,7 @@ from .fcn import FCN, GCN, MCFCN, Refine from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock from .mlp import MLPBlock -from .patchembedding import PatchEmbeddingBlock +from .patchembedding import PatchEmbed, PatchEmbeddingBlock from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock from .segresnet_block import ResBlock from .selfattention import SABlock diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 4c7263c6d5..f02f6342e8 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -9,15 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- -import math -from typing import Sequence, Union +from typing import Sequence, Type, Union import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm -from monai.networks.layers import Conv +from monai.networks.layers import Conv, trunc_normal_ from monai.utils import ensure_tuple_rep, optional_import from monai.utils.module import look_up_option @@ -98,34 +98,18 @@ def __init__( ) self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) self.dropout = nn.Dropout(dropout_rate) - self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) + trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): - self.trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0) + trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - def trunc_normal_(self, tensor, mean, std, a, b): - # From PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - with torch.no_grad(): - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - tensor.uniform_(2 * l - 1, 2 * u - 1) - tensor.erfinv_() - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - tensor.clamp_(min=a, max=b) - return tensor - def forward(self, x): x = self.patch_embeddings(x) if self.pos_embed == "conv": @@ -133,3 +117,84 @@ def forward(self, x): embeddings = x + self.position_embeddings embeddings = self.dropout(embeddings) return embeddings + + +class PatchEmbed(nn.Module): + """ + Patch embedding block based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + + Unlike ViT patch embedding block: (1) input is padded to satisfy window size requirements (2) normalized if + specified (3) position embedding is not used. + + Example:: + + >>> from monai.networks.blocks import PatchEmbed + >>> PatchEmbed(patch_size=2, in_chans=1, embed_dim=48, norm_layer=nn.LayerNorm, spatial_dims=3) + """ + + def __init__( + self, + patch_size: Union[Sequence[int], int] = 2, + in_chans: int = 1, + embed_dim: int = 48, + norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore + spatial_dims: int = 3, + ) -> None: + """ + Args: + patch_size: dimension of patch size. + in_chans: dimension of input channels. + embed_dim: number of linear projection output channels. + norm_layer: normalization layer. + spatial_dims: spatial dimension. 
+ """ + + super().__init__() + + if not (spatial_dims == 2 or spatial_dims == 3): + raise ValueError("spatial dimension should be 2 or 3.") + + patch_size = ensure_tuple_rep(patch_size, spatial_dims) + self.patch_size = patch_size + self.embed_dim = embed_dim + self.proj = Conv[Conv.CONV, spatial_dims]( + in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size + ) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x_shape = x.size() + if len(x_shape) == 5: + _, _, d, h, w = x_shape + if w % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - w % self.patch_size[2])) + if h % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - h % self.patch_size[1])) + if d % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - d % self.patch_size[0])) + + elif len(x_shape) == 4: + _, _, h, w = x.size() + if w % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - w % self.patch_size[1])) + if h % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - h % self.patch_size[0])) + + x = self.proj(x) + if self.norm is not None: + x_shape = x.size() + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + if len(x_shape) == 5: + d, wh, ww = x_shape[2], x_shape[3], x_shape[4] + x = x.transpose(1, 2).view(-1, self.embed_dim, d, wh, ww) + elif len(x_shape) == 4: + wh, ww = x_shape[2], x_shape[3] + x = x.transpose(1, 2).view(-1, self.embed_dim, wh, ww) + return x diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index 5115c00af3..f122dccee6 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -10,6 +10,7 @@ # limitations under the License. from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding +from .drop_path import DropPath from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args from .filtering import BilateralFilter, PHLFilter from .gmm import GaussianMixtureModel @@ -27,3 +28,4 @@ ) from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer +from .weight_init import _no_grad_trunc_normal_, trunc_normal_ diff --git a/monai/networks/layers/drop_path.py b/monai/networks/layers/drop_path.py new file mode 100644 index 0000000000..7bb209ed25 --- /dev/null +++ b/monai/networks/layers/drop_path.py @@ -0,0 +1,45 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn as nn + + +class DropPath(nn.Module): + """Stochastic drop paths per sample for residual blocks. + Based on: + https://github.com/rwightman/pytorch-image-models + """ + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True) -> None: + """ + Args: + drop_prob: drop path probability. + scale_by_keep: scaling by non-dropped probability. 
+ """ + super().__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + if not (0 <= drop_prob <= 1): + raise ValueError("Drop path prob should be between 0 and 1.") + + def drop_path(self, x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + def forward(self, x): + return self.drop_path(x, self.drop_prob, self.training, self.scale_by_keep) diff --git a/monai/networks/layers/weight_init.py b/monai/networks/layers/weight_init.py new file mode 100644 index 0000000000..9b81ef17f8 --- /dev/null +++ b/monai/networks/layers/weight_init.py @@ -0,0 +1,64 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + """Tensor initialization with truncated normal distribution. + Based on: + https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + https://github.com/rwightman/pytorch-image-models + + Args: + tensor: an n-dimensional `torch.Tensor`. + mean: the mean of the normal distribution. + std: the standard deviation of the normal distribution. + a: the minimum cutoff value. + b: the maximum cutoff value. + """ + + def norm_cdf(x): + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + with torch.no_grad(): + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.erfinv_() + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + """Tensor initialization with truncated normal distribution. 
+ Based on: + https://github.com/rwightman/pytorch-image-models + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + + if not std > 0: + raise ValueError("the standard deviation should be greater than zero.") + + if a >= b: + raise ValueError("minimum cutoff value (a) should be smaller than maximum cutoff value (b).") + + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 16686fa25c..394ff51907 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -80,6 +80,7 @@ seresnext50, seresnext101, ) +from .swin_unetr import SwinUNETR from .torchvision_fc import TorchVisionFCModel, TorchVisionFullyConvModel from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex from .unet import UNet, Unet diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py new file mode 100644 index 0000000000..d898da9884 --- /dev/null +++ b/monai/networks/nets/swin_unetr.py @@ -0,0 +1,982 @@ +# Copyright 2020 - 2022 -> (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence, Tuple, Type, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from torch.nn import LayerNorm + +from monai.networks.blocks import MLPBlock as Mlp +from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock +from monai.networks.layers import DropPath, trunc_normal_ +from monai.utils import ensure_tuple_rep, optional_import + +rearrange, _ = optional_import("einops", name="rearrange") + + +class SwinUNETR(nn.Module): + """ + Swin UNETR based on: "Hatamizadeh et al., + Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images + " + """ + + def __init__( + self, + img_size: Union[Sequence[int], int], + in_channels: int, + out_channels: int, + depths: Sequence[int] = (2, 2, 2, 2), + num_heads: Sequence[int] = (3, 6, 12, 24), + feature_size: int = 48, + norm_name: Union[Tuple, str] = "instance", + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + normalize: bool = False, + use_checkpoint: bool = False, + spatial_dims: int = 3, + ) -> None: + """ + Args: + img_size: dimension of input image. + in_channels: dimension of input channels. + out_channels: dimension of output channels. + feature_size: dimension of network feature size. + depths: number of layers in each stage. + num_heads: number of attention heads. + norm_name: feature normalization type and arguments. + drop_rate: dropout rate. + attn_drop_rate: attention dropout rate. + dropout_path_rate: drop path rate. + normalize: normalize output intermediate features in each stage. 
+ use_checkpoint: use gradient checkpointing for reduced memory usage. + spatial_dims: number of spatial dims. + + Examples:: + + # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48. + >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48) + + # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage. + >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2)) + + # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing. + >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2) + + """ + + super().__init__() + + img_size = ensure_tuple_rep(img_size, spatial_dims) + patch_size = ensure_tuple_rep(2, spatial_dims) + window_size = ensure_tuple_rep(7, spatial_dims) + + if not (spatial_dims == 2 or spatial_dims == 3): + raise ValueError("spatial dimension should be 2 or 3.") + + for m, p in zip(img_size, patch_size): + for i in range(5): + if m % np.power(p, i + 1) != 0: + raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.") + + if not (0 <= drop_rate <= 1): + raise ValueError("dropout rate should be between 0 and 1.") + + if not (0 <= attn_drop_rate <= 1): + raise ValueError("attention dropout rate should be between 0 and 1.") + + if not (0 <= dropout_path_rate <= 1): + raise ValueError("drop path rate should be between 0 and 1.") + + if feature_size % 12 != 0: + raise ValueError("feature_size should be divisible by 12.") + + self.normalize = normalize + + self.swinViT = SwinTransformer( + in_chans=in_channels, + embed_dim=feature_size, + window_size=window_size, + patch_size=patch_size, + depths=depths, + num_heads=num_heads, + mlp_ratio=4.0, + qkv_bias=True, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dropout_path_rate, + norm_layer=nn.LayerNorm, + use_checkpoint=use_checkpoint, + spatial_dims=spatial_dims, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder2 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder3 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=2 * feature_size, + out_channels=2 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder4 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=4 * feature_size, + out_channels=4 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder10 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=16 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.decoder5 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=8 * feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder4 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 8, + out_channels=feature_size * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder3 = UnetrUpBlock( + 
spatial_dims=spatial_dims, + in_channels=feature_size * 4, + out_channels=feature_size * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + self.decoder2 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 2, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder1 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.out = UnetOutBlock( + spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels + ) # type: ignore + + def load_from(self, weights): + + with torch.no_grad(): + self.swinViT.patch_embed.proj.weight.copy_(weights["state_dict"]["module.patch_embed.proj.weight"]) + self.swinViT.patch_embed.proj.bias.copy_(weights["state_dict"]["module.patch_embed.proj.bias"]) + for bname, block in self.swinViT.layers1[0].blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers1") + self.swinViT.layers1[0].downsample.reduction.weight.copy_( + weights["state_dict"]["module.layers1.0.downsample.reduction.weight"] + ) + self.swinViT.layers1[0].downsample.norm.weight.copy_( + weights["state_dict"]["module.layers1.0.downsample.norm.weight"] + ) + self.swinViT.layers1[0].downsample.norm.bias.copy_( + weights["state_dict"]["module.layers1.0.downsample.norm.bias"] + ) + for bname, block in self.swinViT.layers2[0].blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers2") + self.swinViT.layers2[0].downsample.reduction.weight.copy_( + weights["state_dict"]["module.layers2.0.downsample.reduction.weight"] + ) + self.swinViT.layers2[0].downsample.norm.weight.copy_( + weights["state_dict"]["module.layers2.0.downsample.norm.weight"] + ) + self.swinViT.layers2[0].downsample.norm.bias.copy_( + weights["state_dict"]["module.layers2.0.downsample.norm.bias"] + ) + for bname, block in self.swinViT.layers3[0].blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers3") + self.swinViT.layers3[0].downsample.reduction.weight.copy_( + weights["state_dict"]["module.layers3.0.downsample.reduction.weight"] + ) + self.swinViT.layers3[0].downsample.norm.weight.copy_( + weights["state_dict"]["module.layers3.0.downsample.norm.weight"] + ) + self.swinViT.layers3[0].downsample.norm.bias.copy_( + weights["state_dict"]["module.layers3.0.downsample.norm.bias"] + ) + for bname, block in self.swinViT.layers4[0].blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers4") + self.swinViT.layers4[0].downsample.reduction.weight.copy_( + weights["state_dict"]["module.layers4.0.downsample.reduction.weight"] + ) + self.swinViT.layers4[0].downsample.norm.weight.copy_( + weights["state_dict"]["module.layers4.0.downsample.norm.weight"] + ) + self.swinViT.layers4[0].downsample.norm.bias.copy_( + weights["state_dict"]["module.layers4.0.downsample.norm.bias"] + ) + self.swinViT.norm.weight.copy_(weights["state_dict"]["module.norm.weight"]) + self.swinViT.norm.bias.copy_(weights["state_dict"]["module.norm.bias"]) + + def forward(self, x_in): + hidden_states_out = self.swinViT(x_in, self.normalize) + enc0 = self.encoder1(x_in) + enc1 = self.encoder2(hidden_states_out[0]) + enc2 = self.encoder3(hidden_states_out[1]) + enc3 = self.encoder4(hidden_states_out[2]) + dec4 = self.encoder10(hidden_states_out[4]) + dec3 = self.decoder5(dec4, hidden_states_out[3]) 
+ dec2 = self.decoder4(dec3, enc3) + dec1 = self.decoder3(dec2, enc2) + dec0 = self.decoder2(dec1, enc1) + out = self.decoder1(dec0, enc0) + logits = self.out(out) + return logits + + +def window_partition(x, window_size): + """window partition operation based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + + Args: + x: input tensor. + window_size: local window size. + """ + x_shape = x.size() + if len(x_shape) == 5: + b, d, h, w, c = x_shape + x = x.view( + b, + d // window_size[0], + window_size[0], + h // window_size[1], + window_size[1], + w // window_size[2], + window_size[2], + c, + ) + windows = ( + x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c) + ) + elif len(x_shape) == 4: + b, h, w, c = x.shape + x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c) + return windows + + +def window_reverse(windows, window_size, dims): + """window reverse operation based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + + Args: + windows: windows tensor. + window_size: local window size. + dims: dimension values. + """ + if len(dims) == 4: + b, d, h, w = dims + x = windows.view( + b, + d // window_size[0], + h // window_size[1], + w // window_size[2], + window_size[0], + window_size[1], + window_size[2], + -1, + ) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1) + + elif len(dims) == 3: + b, h, w = dims + x = windows.view(b, h // window_size[0], w // window_size[0], window_size[0], window_size[1], -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1) + return x + + +def get_window_size(x_size, window_size, shift_size=None): + """Computing window size based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + + Args: + x_size: input size. + window_size: local window size. + shift_size: window shifting size. + """ + + use_window_size = list(window_size) + if shift_size is not None: + use_shift_size = list(shift_size) + for i in range(len(x_size)): + if x_size[i] <= window_size[i]: + use_window_size[i] = x_size[i] + if shift_size is not None: + use_shift_size[i] = 0 + + if shift_size is None: + return tuple(use_window_size) + else: + return tuple(use_window_size), tuple(use_shift_size) + + +class WindowAttention(nn.Module): + """ + Window based multi-head self attention module with relative position bias based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, + dim: int, + num_heads: int, + window_size: Sequence[int], + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + """ + Args: + dim: number of feature channels. + num_heads: number of attention heads. + window_size: local window size. + qkv_bias: add a learnable bias to query, key, value. + attn_drop: attention dropout rate. + proj_drop: dropout rate of output. 
+ """ + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + mesh_args = torch.meshgrid.__kwdefaults__ + + if len(self.window_size) == 3: + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1), + num_heads, + ) + ) + coords_d = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + if mesh_args is not None: + coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij")) + else: + coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + elif len(self.window_size) == 2: + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + if mesh_args is not None: + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) + else: + coords = torch.stack(torch.meshgrid(coords_h, coords_w)) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask): + b, n, c = x.shape + qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + attn = q @ k.transpose(-2, -1) + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index[:n, :n].reshape(-1) + ].reshape(n, n, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + if mask is not None: + nw = mask.shape[0] + attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, n, n) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(b, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ + Swin Transformer block based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + 
def __init__( + self, + dim: int, + num_heads: int, + window_size: Sequence[int], + shift_size: Sequence[int], + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer: str = "GELU", + norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore + use_checkpoint: bool = False, + ) -> None: + """ + Args: + dim: number of feature channels. + num_heads: number of attention heads. + window_size: local window size. + shift_size: window shift size. + mlp_ratio: ratio of mlp hidden dim to embedding dim. + qkv_bias: add a learnable bias to query, key, value. + drop: dropout rate. + attn_drop: attention dropout rate. + drop_path: stochastic depth rate. + act_layer: activation layer. + norm_layer: normalization layer. + use_checkpoint: use gradient checkpointing for reduced memory usage. + """ + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_checkpoint = use_checkpoint + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=self.window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin") + + def forward_part1(self, x, mask_matrix): + x_shape = x.size() + x = self.norm1(x) + if len(x_shape) == 5: + b, d, h, w, c = x.shape + window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) + pad_l = pad_t = pad_d0 = 0 + pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0] + pad_b = (window_size[1] - h % window_size[1]) % window_size[1] + pad_r = (window_size[2] - w % window_size[2]) % window_size[2] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) + _, dp, hp, wp, _ = x.shape + dims = [b, dp, hp, wp] + + elif len(x_shape) == 4: + b, h, w, c = x.shape + window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size) + pad_l = pad_t = 0 + pad_r = (window_size[0] - h % window_size[0]) % window_size[0] + pad_b = (window_size[1] - w % window_size[1]) % window_size[1] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, hp, wp, _ = x.shape + dims = [b, hp, wp] + + if any(i > 0 for i in shift_size): + if len(x_shape) == 5: + shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) + elif len(x_shape) == 4: + shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + x_windows = window_partition(shifted_x, window_size) + attn_windows = self.attn(x_windows, mask=attn_mask) + attn_windows = attn_windows.view(-1, *(window_size + (c,))) + shifted_x = window_reverse(attn_windows, window_size, dims) + if any(i > 0 for i in shift_size): + if len(x_shape) == 5: + x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) + elif len(x_shape) == 4: + x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) + else: + x = shifted_x + + if len(x_shape) == 5: + if pad_d1 > 0 or pad_r > 0 or pad_b > 0: + x = x[:, :d, :h, :w, :].contiguous() + elif len(x_shape) == 4: + if pad_r > 0 or pad_b > 0: + x = x[:, :h, :w, :].contiguous() + + return 
x + + def forward_part2(self, x): + return self.drop_path(self.mlp(self.norm2(x))) + + def load_from(self, weights, n_block, layer): + root = f"module.{layer}.0.blocks.{n_block}." + block_names = [ + "norm1.weight", + "norm1.bias", + "attn.relative_position_bias_table", + "attn.relative_position_index", + "attn.qkv.weight", + "attn.qkv.bias", + "attn.proj.weight", + "attn.proj.bias", + "norm2.weight", + "norm2.bias", + "mlp.linear1.weight", + "mlp.linear1.bias", + "mlp.linear2.weight", + "mlp.linear2.bias", + ] + with torch.no_grad(): + self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]]) + self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]]) + self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]]) + self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]]) + self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]]) + self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]]) + self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]]) + self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]]) + self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]]) + self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]]) + self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]]) + self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]]) + self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]]) + self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]]) + + def forward(self, x, mask_matrix): + shortcut = x + if self.use_checkpoint: + x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) + else: + x = self.forward_part1(x, mask_matrix) + x = shortcut + self.drop_path(x) + if self.use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_part2, x) + else: + x = x + self.forward_part2(x) + return x + + +class PatchMerging(nn.Module): + """ + Patch merging layer based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3 + ) -> None: # type: ignore + """ + Args: + dim: number of feature channels. + norm_layer: normalization layer. + spatial_dims: number of spatial dims. 
+ """ + + super().__init__() + self.dim = dim + if spatial_dims == 3: + self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False) + self.norm = norm_layer(8 * dim) + elif spatial_dims == 2: + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + + x_shape = x.size() + if len(x_shape) == 5: + b, d, h, w, c = x_shape + pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, d % 2, 0, w % 2, 0, h % 2)) + x0 = x[:, 0::2, 0::2, 0::2, :] + x1 = x[:, 1::2, 0::2, 0::2, :] + x2 = x[:, 0::2, 1::2, 0::2, :] + x3 = x[:, 0::2, 0::2, 1::2, :] + x4 = x[:, 1::2, 0::2, 1::2, :] + x5 = x[:, 0::2, 1::2, 0::2, :] + x6 = x[:, 0::2, 0::2, 1::2, :] + x7 = x[:, 1::2, 1::2, 1::2, :] + x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1) + + elif len(x_shape) == 4: + b, h, w, c = x_shape + pad_input = (h % 2 == 1) or (w % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2)) + x0 = x[:, 0::2, 0::2, :] + x1 = x[:, 1::2, 0::2, :] + x2 = x[:, 0::2, 1::2, :] + x3 = x[:, 1::2, 1::2, :] + x = torch.cat([x0, x1, x2, x3], -1) + + x = self.norm(x) + x = self.reduction(x) + return x + + +def compute_mask(dims, window_size, shift_size, device): + """Computing region masks based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + + Args: + dims: dimension values. + window_size: local window size. + shift_size: shift size. + device: device. + """ + + cnt = 0 + + if len(dims) == 3: + d, h, w = dims + img_mask = torch.zeros((1, d, h, w, 1), device=device) + for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): + for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): + for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None): + img_mask[:, d, h, w, :] = cnt + cnt += 1 + + elif len(dims) == 2: + h, w = dims + img_mask = torch.zeros((1, h, w, 1), device=device) + for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): + for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, window_size) + mask_windows = mask_windows.squeeze(-1) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + +class BasicLayer(nn.Module): + """ + Basic Swin Transformer layer in one stage based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, + dim: int, + depth: int, + num_heads: int, + window_size: Sequence[int], + drop_path: list, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + drop: float = 0.0, + attn_drop: float = 0.0, + norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore + downsample: isinstance = None, # type: ignore + use_checkpoint: bool = False, + ) -> None: + """ + Args: + dim: number of feature channels. + depths: number of layers in each stage. + num_heads: number of attention heads. + window_size: local window size. + drop_path: stochastic depth rate. + mlp_ratio: ratio of mlp hidden dim to embedding dim. 
+ qkv_bias: add a learnable bias to query, key, value. + drop: dropout rate. + attn_drop: attention dropout rate. + norm_layer: normalization layer. + downsample: downsample layer at the end of the layer. + use_checkpoint: use gradient checkpointing for reduced memory usage. + """ + + super().__init__() + self.window_size = window_size + self.shift_size = tuple(i // 2 for i in window_size) + self.no_shift = tuple(0 for i in window_size) + self.depth = depth + self.use_checkpoint = use_checkpoint + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=self.window_size, + shift_size=self.no_shift if (i % 2 == 0) else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + use_checkpoint=use_checkpoint, + ) + for i in range(depth) + ] + ) + self.downsample = downsample + if self.downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size)) + + def forward(self, x): + x_shape = x.size() + if len(x_shape) == 5: + b, c, d, h, w = x_shape + window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) + x = rearrange(x, "b c d h w -> b d h w c") + dp = int(np.ceil(d / window_size[0])) * window_size[0] + hp = int(np.ceil(h / window_size[1])) * window_size[1] + wp = int(np.ceil(w / window_size[2])) * window_size[2] + attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(b, d, h, w, -1) + if self.downsample is not None: + x = self.downsample(x) + x = rearrange(x, "b d h w c -> b c d h w") + + elif len(x_shape) == 4: + b, c, h, w = x_shape + window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size) + x = rearrange(x, "b c h w -> b h w c") + hp = int(np.ceil(h / window_size[0])) * window_size[0] + wp = int(np.ceil(w / window_size[1])) * window_size[1] + attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(b, h, w, -1) + if self.downsample is not None: + x = self.downsample(x) + x = rearrange(x, "b h w c -> b c h w") + return x + + +class SwinTransformer(nn.Module): + """ + Swin Transformer based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, + in_chans: int, + embed_dim: int, + window_size: Sequence[int], + patch_size: Sequence[int], + depths: Sequence[int], + num_heads: Sequence[int], + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore + patch_norm: bool = False, + use_checkpoint: bool = False, + spatial_dims: int = 3, + ) -> None: + """ + Args: + in_chans: dimension of input channels. + embed_dim: number of linear projection output channels. + window_size: local window size. + patch_size: patch size. + depths: number of layers in each stage. + num_heads: number of attention heads. + mlp_ratio: ratio of mlp hidden dim to embedding dim. + qkv_bias: add a learnable bias to query, key, value. + drop_rate: dropout rate. + attn_drop_rate: attention dropout rate. + drop_path_rate: stochastic depth rate. + norm_layer: normalization layer. 
+ patch_norm: add normalization after patch embedding. + use_checkpoint: use gradient checkpointing for reduced memory usage. + spatial_dims: spatial dimension. + """ + + super().__init__() + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.window_size = window_size + self.patch_size = patch_size + self.patch_embed = PatchEmbed( + patch_size=self.patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, # type: ignore + spatial_dims=spatial_dims, + ) + self.pos_drop = nn.Dropout(p=drop_rate) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + self.layers1 = nn.ModuleList() + self.layers2 = nn.ModuleList() + self.layers3 = nn.ModuleList() + self.layers4 = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=self.window_size, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + norm_layer=norm_layer, + downsample=PatchMerging, + use_checkpoint=use_checkpoint, + ) + if i_layer == 0: + self.layers1.append(layer) + elif i_layer == 1: + self.layers2.append(layer) + elif i_layer == 2: + self.layers3.append(layer) + elif i_layer == 3: + self.layers4.append(layer) + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.norm = norm_layer(self.num_features) + + def proj_out(self, x, normalize=False): + if normalize: + x_shape = x.size() + if len(x_shape) == 5: + n, ch, d, h, w = x_shape + x = rearrange(x, "n c d h w -> n d h w c") + x = F.layer_norm(x, [ch]) + x = rearrange(x, "n d h w c -> n c d h w") + elif len(x_shape) == 4: + n, ch, h, w = x_shape + x = rearrange(x, "n c h w -> n h w c") + x = F.layer_norm(x, [ch]) + x = rearrange(x, "n h w c -> n c h w") + return x + + def forward(self, x, normalize=False): + x0 = self.patch_embed(x) + x0 = self.pos_drop(x0) + x0_out = self.proj_out(x0, normalize) + x1 = self.layers1[0](x0.contiguous()) + x1_out = self.proj_out(x1, normalize) + x2 = self.layers2[0](x1.contiguous()) + x2_out = self.proj_out(x2, normalize) + x3 = self.layers3[0](x2.contiguous()) + x3_out = self.proj_out(x3, normalize) + x4 = self.layers4[0](x3.contiguous()) + x4_out = self.proj_out(x4, normalize) + return [x0_out, x1_out, x2_out, x3_out, x4_out] diff --git a/tests/test_drop_path.py b/tests/test_drop_path.py new file mode 100644 index 0000000000..f8ea454228 --- /dev/null +++ b/tests/test_drop_path.py @@ -0,0 +1,43 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
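For reference, the five hidden states returned by `SwinTransformer.forward` above halve each spatial dimension and double the channel count per stage. A rough shape walk-through, as a sketch assuming the default 3D configuration with `feature_size=48` and `einops` installed::

    import torch
    from monai.networks.nets import SwinUNETR

    net = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=2, feature_size=48)
    hidden = net.swinViT(torch.randn(1, 1, 96, 96, 96), net.normalize)
    print([tuple(h.shape) for h in hidden])
    # [(1, 48, 48, 48, 48), (1, 96, 24, 24, 24), (1, 192, 12, 12, 12),
    #  (1, 384, 6, 6, 6), (1, 768, 3, 3, 3)]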
+ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.layers import DropPath + +TEST_CASES = [ + [{"drop_prob": 0.0, "scale_by_keep": True}, (1, 8, 8)], + [{"drop_prob": 0.7, "scale_by_keep": False}, (2, 16, 16, 16)], + [{"drop_prob": 0.3, "scale_by_keep": True}, (6, 16, 12)], +] + +TEST_ERRORS = [[{"drop_prob": 2, "scale_by_keep": False}, (1, 24, 6)]] + + +class TestDropPath(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape): + im = torch.rand(input_shape) + dr_path = DropPath(**input_param) + out = dr_path(im) + self.assertEqual(out.shape, input_shape) + + @parameterized.expand(TEST_ERRORS) + def test_ill_arg(self, input_param, input_shape): + with self.assertRaises(ValueError): + DropPath(**input_param) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py index 4af2b47ba5..6971eb0463 100644 --- a/tests/test_patchembedding.py +++ b/tests/test_patchembedding.py @@ -13,10 +13,11 @@ from unittest import skipUnless import torch +import torch.nn as nn from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.blocks.patchembedding import PatchEmbeddingBlock +from monai.networks.blocks.patchembedding import PatchEmbed, PatchEmbeddingBlock from monai.utils import optional_import einops, has_einops = optional_import("einops") @@ -48,6 +49,26 @@ test_case[0]["spatial_dims"] = 2 # type: ignore TEST_CASE_PATCHEMBEDDINGBLOCK.append(test_case) +TEST_CASE_PATCHEMBED = [] +for patch_size in [2]: + for in_chans in [1, 4]: + for img_size in [96]: + for embed_dim in [6, 12]: + for norm_layer in [nn.LayerNorm]: + for nd in [2, 3]: + test_case = [ + { + "patch_size": (patch_size,) * nd, + "in_chans": in_chans, + "embed_dim": embed_dim, + "norm_layer": norm_layer, + "spatial_dims": nd, + }, + (2, in_chans, *([img_size] * nd)), + (2, embed_dim, *([img_size // patch_size] * nd)), + ] + TEST_CASE_PATCHEMBED.append(test_case) + class TestPatchEmbeddingBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_PATCHEMBEDDINGBLOCK) @@ -115,5 +136,19 @@ def test_ill_arg(self): ) +class TestPatchEmbed(unittest.TestCase): + @parameterized.expand(TEST_CASE_PATCHEMBED) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + net = PatchEmbed(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + PatchEmbed(patch_size=(2, 2, 2), in_chans=1, embed_dim=24, norm_layer=nn.LayerNorm, spatial_dims=5) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py new file mode 100644 index 0000000000..0d48e99c44 --- /dev/null +++ b/tests/test_swin_unetr.py @@ -0,0 +1,89 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
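The `TestPatchEmbed` cases above only exercise evenly divisible inputs; as its docstring notes, the block also pads inputs whose sizes are not multiples of the patch size. A sketch of that behaviour, assuming a 3D input with an odd depth::

    import torch
    import torch.nn as nn
    from monai.networks.blocks import PatchEmbed

    embed = PatchEmbed(patch_size=2, in_chans=1, embed_dim=48, norm_layer=nn.LayerNorm, spatial_dims=3)
    out = embed(torch.randn(1, 1, 97, 96, 96))  # depth 97 is padded up to 98 before projection
    print(out.shape)  # torch.Size([1, 48, 49, 48, 48])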
+ +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.swin_unetr import SwinUNETR +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + +TEST_CASE_SWIN_UNETR = [] +for attn_drop_rate in [0.4]: + for in_channels in [1]: + for depth in [[2, 1, 1, 1], [1, 2, 1, 1]]: + for out_channels in [2]: + for img_size in [64]: + for feature_size in [12]: + for norm_name in ["instance"]: + for nd in (2, 3): + test_case = [ + { + "in_channels": in_channels, + "out_channels": out_channels, + "img_size": (img_size,) * nd, + "feature_size": feature_size, + "depths": depth, + "norm_name": norm_name, + "attn_drop_rate": attn_drop_rate, + }, + (2, in_channels, *([img_size] * nd)), + (2, out_channels, *([img_size] * nd)), + ] + if nd == 2: + test_case[0]["spatial_dims"] = 2 # type: ignore + TEST_CASE_SWIN_UNETR.append(test_case) + + +class TestSWINUNETR(unittest.TestCase): + @parameterized.expand(TEST_CASE_SWIN_UNETR) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + net = SwinUNETR(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + SwinUNETR( + in_channels=1, + out_channels=3, + img_size=(128, 128, 128), + feature_size=24, + norm_name="instance", + attn_drop_rate=4, + ) + + with self.assertRaises(ValueError): + SwinUNETR(in_channels=1, out_channels=2, img_size=(96, 96), feature_size=48, norm_name="instance") + + with self.assertRaises(ValueError): + SwinUNETR(in_channels=1, out_channels=4, img_size=(96, 96, 96), feature_size=50, norm_name="instance") + + with self.assertRaises(ValueError): + SwinUNETR( + in_channels=1, + out_channels=3, + img_size=(85, 85, 85), + feature_size=24, + norm_name="instance", + drop_rate=0.4, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_weight_init.py b/tests/test_weight_init.py new file mode 100644 index 0000000000..c850ff4ce6 --- /dev/null +++ b/tests/test_weight_init.py @@ -0,0 +1,47 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
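The `test_ill_arg` cases above follow directly from the constructor checks: with `patch_size=2` and five stage-wise halvings, every entry of `img_size` must be divisible by 2**5 = 32, and `feature_size` must be divisible by 12. A sketch of the boundary between valid and invalid arguments::

    from monai.networks.nets import SwinUNETR

    SwinUNETR(img_size=(64, 64), in_channels=1, out_channels=2, feature_size=12, spatial_dims=2)  # valid
    SwinUNETR(img_size=(85, 85, 85), in_channels=1, out_channels=2, feature_size=12)  # raises ValueError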
+ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.layers import trunc_normal_ + +TEST_CASES = [ + [{"mean": 0.0, "std": 1.0, "a": 2, "b": 4}, (6, 12, 3, 1, 7)], + [{"mean": 0.3, "std": 0.6, "a": -1.0, "b": 1.3}, (1, 4, 4, 4)], + [{"mean": 0.1, "std": 0.4, "a": 1.3, "b": 1.8}, (5, 7, 7, 8, 9)], +] + +TEST_ERRORS = [ + [{"mean": 0.0, "std": 1.0, "a": 5, "b": 1.1}, (1, 1, 8, 8, 8)], + [{"mean": 0.3, "std": -0.1, "a": 1.0, "b": 2.0}, (8, 5, 2, 6, 9)], + [{"mean": 0.7, "std": 0.0, "a": 0.1, "b": 2.0}, (4, 12, 23, 17)], +] + + +class TestWeightInit(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape): + im = torch.rand(input_shape) + trunc_normal_(im, **input_param) + self.assertEqual(im.shape, input_shape) + + @parameterized.expand(TEST_ERRORS) + def test_ill_arg(self, input_param, input_shape): + with self.assertRaises(ValueError): + im = torch.rand(input_shape) + trunc_normal_(im, **input_param) + + +if __name__ == "__main__": + unittest.main() From 06ccf35a27136341764c13cf8f813c76573851c0 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 5 May 2022 12:37:41 +0800 Subject: [PATCH 02/28] 4217 Update PyTorch docker to 22.04 (#4218) * [DLMED] update to 22.04 Signed-off-by: Nic Ma * fixes unit test tests.test_lr_finder Signed-off-by: Wenqi Li * test new_empty Signed-off-by: Wenqi Li Co-authored-by: Wenqi Li Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com> Signed-off-by: kbressem --- .github/workflows/cron.yml | 6 +++--- .github/workflows/pythonapp-gpu.yml | 4 ++-- Dockerfile | 2 +- monai/data/meta_obj.py | 2 +- monai/data/meta_tensor.py | 15 +++++++++++++-- tests/test_meta_tensor.py | 3 +-- 6 files changed, 21 insertions(+), 11 deletions(-) diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index 734a84ff2f..08065147e5 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -62,7 +62,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.03"] # 21.02, 21.10 for backward comp. + container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.04"] # 21.02, 21.10 for backward comp. container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -106,7 +106,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.03"] # 21.02, 21.10 for backward comp. + container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.04"] # 21.02, 21.10 for backward comp. 
container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -204,7 +204,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' needs: cron-gpu # so that monai itself is verified first container: - image: nvcr.io/nvidia/pytorch:22.03-py3 # testing with the latest pytorch base image + image: nvcr.io/nvidia/pytorch:22.04-py3 # testing with the latest pytorch base image options: "--gpus all --ipc=host" runs-on: [self-hosted, linux, x64, common] steps: diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml index 50bbe13062..2fdfa5a80f 100644 --- a/.github/workflows/pythonapp-gpu.yml +++ b/.github/workflows/pythonapp-gpu.yml @@ -46,9 +46,9 @@ jobs: base: "nvcr.io/nvidia/pytorch:21.10-py3" - environment: PT111+CUDA116 # we explicitly set pytorch to -h to avoid pip install error - # 22.03: 1.12.0a0+2c916ef + # 22.04: 1.12.0a0+bd13bc6 pytorch: "-h" - base: "nvcr.io/nvidia/pytorch:22.03-py3" + base: "nvcr.io/nvidia/pytorch:22.04-py3" - environment: PT110+CUDA102 pytorch: "torch==1.10.2 torchvision==0.11.3" base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04" diff --git a/Dockerfile b/Dockerfile index dc76584d5a..1b022fc92e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,7 @@ # To build with a different base image # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag. -ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:22.03-py3 +ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:22.04-py3 FROM ${PYTORCH_IMAGE} LABEL maintainer="monai.contact@gmail.com" diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index e38e009e96..3e35a6dd4b 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -153,7 +153,7 @@ def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Call Returns: Returns `None`, but `self` should be updated to have the copied attribute. """ - attributes = [getattr(i, attribute) for i in input_objs] + attributes = [getattr(i, attribute) for i in input_objs if hasattr(i, attribute)] if len(attributes) > 0: val = attributes[0] if deep_copy: diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 9196f0186c..d44b780a5e 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -186,8 +186,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any: kwargs = {} ret = super().__torch_function__(func, types, args, kwargs) # if `out` has been used as argument, metadata is not copied, nothing to do. - if "out" in kwargs: - return ret + # if "out" in kwargs: + # return ret # we might have 1 or multiple outputs. Might be MetaTensor, might be something # else (e.g., `__repr__` returns a string). # Convert to list (if necessary), process, and at end remove list if one was added. 
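The next hunk adds `MetaTensor.new_empty`. In recent PyTorch, `copy.deepcopy` rebuilds a dense tensor through `new_empty` followed by `set_`, so a subclass whose `new_empty` does not return its own type can degrade to a plain `torch.Tensor`. A minimal illustration of the pattern, using a hypothetical toy subclass rather than MONAI's actual class::

    import copy

    import torch

    class TaggedTensor(torch.Tensor):  # hypothetical stand-in for MetaTensor
        def new_empty(self, size, dtype=None, device=None, requires_grad=False):
            # return the subclass so deepcopy reconstructs a TaggedTensor
            return type(self)(
                torch.Tensor.new_empty(self, size, dtype=dtype, device=device, requires_grad=requires_grad)
            )

    t = TaggedTensor(torch.zeros(2, 3))
    print(type(copy.deepcopy(t)).__name__)  # TaggedTensor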
@@ -232,3 +232,14 @@ def affine(self) -> torch.Tensor: def affine(self, d: torch.Tensor) -> None: """Set the affine.""" self.meta["affine"] = d + + def new_empty(self, size, dtype=None, device=None, requires_grad=False): + """ + must be defined for deepcopy to work + + See: + - https://pytorch.org/docs/stable/generated/torch.Tensor.new_empty.html#torch-tensor-new-empty + """ + return type(self)( + self.as_tensor().new_empty(size=size, dtype=dtype, device=device, requires_grad=requires_grad) + ) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 05356fcc84..fb6d922218 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -272,14 +272,13 @@ def test_amp(self): def test_out(self): """Test when `out` is given as an argument.""" m1, _ = self.get_im() - m1_orig = deepcopy(m1) m2, _ = self.get_im() m3, _ = self.get_im() torch.add(m2, m3, out=m1) m1_add = m2 + m3 assert_allclose(m1, m1_add) - self.check_meta(m1, m1_orig) + # self.check_meta(m1, m2) # meta is from first input tensor @parameterized.expand(TESTS) def test_collate(self, device, dtype): From 2cf927c7338f195eae37d8ec717910e4a2cc72a2 Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Thu, 5 May 2022 13:26:52 +0800 Subject: [PATCH 03/28] Add InstanceNorm3dNVFuser support (#4194) * implement the base class Signed-off-by: Yiheng Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add unittest Signed-off-by: Yiheng Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * autofix Signed-off-by: Yiheng Wang * switch to call apex directly Signed-off-by: Yiheng Wang * uncomment unittest Signed-off-by: Yiheng Wang * add apex install link in docstring Signed-off-by: Yiheng Wang * add channels_last_3d test case Signed-off-by: Yiheng Wang * rewrite types Signed-off-by: Yiheng Wang * change types Signed-off-by: Yiheng Wang * add docstrings Signed-off-by: Yiheng Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com> Signed-off-by: kbressem --- monai/networks/layers/factories.py | 29 ++++++++++++++++++++++++++- monai/networks/nets/dynunet.py | 2 ++ tests/test_dynunet.py | 32 +++++++++++++++++++++++++++++- 3 files changed, 61 insertions(+), 2 deletions(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 6379f49449..b808c24de0 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -60,11 +60,14 @@ def use_factory(fact_args): layer = use_factory( (fact.TEST, kwargs) ) """ +import warnings from typing import Any, Callable, Dict, Tuple, Type, Union import torch.nn as nn -from monai.utils import look_up_option +from monai.utils import look_up_option, optional_import + +InstanceNorm3dNVFuser, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser") __all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"] @@ -242,6 +245,30 @@ def sync_batch_factory(_dim) -> Type[nn.SyncBatchNorm]: return nn.SyncBatchNorm +@Norm.factory_function("instance_nvfuser") +def instance_nvfuser_factory(dim): + """ + `InstanceNorm3dNVFuser` is a faster verison of InstanceNorm layer and implemented in `apex`. + It only supports 3d tensors as the input. It also requires to use with CUDA and non-Windows OS. 
+    In this function, if the required library `apex.normalization.InstanceNorm3dNVFuser` does not exist,
+    `nn.InstanceNorm3d` will be returned instead.
+    This layer is based on a customized autograd function, which is not supported in TorchScript currently.
+    Please switch to `nn.InstanceNorm3d` if TorchScript is necessary.
+
+    Please check the following link for more details about how to install `apex`:
+    https://github.com/NVIDIA/apex#installation
+
+    """
+    types = (nn.InstanceNorm1d, nn.InstanceNorm2d)
+    if dim != 3:
+        warnings.warn(f"`InstanceNorm3dNVFuser` only supports 3d cases, use {types[dim - 1]} instead.")
+        return types[dim - 1]
+    if not has_nvfuser:
+        warnings.warn("`apex.normalization.InstanceNorm3dNVFuser` is not found, use nn.InstanceNorm3d instead.")
+        return nn.InstanceNorm3d
+    return InstanceNorm3dNVFuser
+
+
 Act.add_factory_callable("elu", lambda: nn.modules.ELU)
 Act.add_factory_callable("relu", lambda: nn.modules.ReLU)
 Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU)
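For illustration, a minimal factory-lookup sketch (not part of the patch; it assumes a working `apex` install, and otherwise relies on the `nn.InstanceNorm3d` fallback implemented above):

    from monai.networks.layers import Norm

    # factory lookup by (name, spatial dims); returns the layer class, not an instance
    norm_cls = Norm["instance_nvfuser", 3]
    layer = norm_cls(num_features=32, affine=True)

diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py
index e858dcbb9b..053ab255b8 100644
--- a/monai/networks/nets/dynunet.py
+++ b/monai/networks/nets/dynunet.py
@@ -104,6 +104,8 @@ class DynUNet(nn.Module):
             If not specified, the way which nnUNet used will be employed. Defaults to ``None``.
         dropout: dropout ratio. Defaults to no dropout.
         norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``.
+            `INSTANCE_NVFUSER` is a faster version of the instance norm layer; it can be used when:
+            1) `spatial_dims=3`, 2) CUDA device is available, 3) `apex` is installed and 4) non-Windows OS is used.
         act_name: activation layer type and arguments. Defaults to ``leakyrelu``.
         deep_supervision: whether to add deep supervision head before output. Defaults to ``False``.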
If ``True``, in training mode, the forward function will output not only the final feature map diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 36ac9d0309..14006b96e6 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -17,7 +17,10 @@ from monai.networks import eval_mode from monai.networks.nets import DynUNet -from tests.utils import test_script_save +from monai.utils import optional_import +from tests.utils import skip_if_no_cuda, skip_if_windows, test_script_save + +_, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser") device = "cuda" if torch.cuda.is_available() else "cpu" @@ -118,6 +121,33 @@ def test_script(self): test_script_save(net, test_data) +@skip_if_no_cuda +@skip_if_windows +@unittest.skipUnless(has_nvfuser, "To use `instance_nvfuser`, `apex.normalization.InstanceNorm3dNVFuser` is needed.") +class TestDynUNetWithInstanceNorm3dNVFuser(unittest.TestCase): + @parameterized.expand([TEST_CASE_DYNUNET_3D[0]]) + def test_consistency(self, input_param, input_shape, _): + for eps in [1e-4, 1e-5]: + for momentum in [0.1, 0.01]: + for affine in [True, False]: + norm_param = {"eps": eps, "momentum": momentum, "affine": affine} + input_param["norm_name"] = ("instance", norm_param) + input_param_fuser = input_param.copy() + input_param_fuser["norm_name"] = ("instance_nvfuser", norm_param) + for memory_format in [torch.contiguous_format, torch.channels_last_3d]: + net = DynUNet(**input_param).to("cuda:0", memory_format=memory_format) + net_fuser = DynUNet(**input_param_fuser).to("cuda:0", memory_format=memory_format) + net_fuser.load_state_dict(net.state_dict()) + + input_tensor = torch.randn(input_shape).to("cuda:0", memory_format=memory_format) + with eval_mode(net): + result = net(input_tensor) + with eval_mode(net_fuser): + result_fuser = net_fuser(input_tensor) + + torch.testing.assert_close(result, result_fuser) + + class TestDynUNetDeepSupervision(unittest.TestCase): @parameterized.expand(TEST_CASE_DEEP_SUPERVISION) def test_shape(self, input_param, input_shape, expected_shape): From c9aee9220db58bd3cd7ee1293cafaab0be9224b7 Mon Sep 17 00:00:00 2001 From: Ryan Clanton <55164720+ryancinsight@users.noreply.github.com> Date: Fri, 6 May 2022 07:10:59 -0400 Subject: [PATCH 04/28] Update dice.py (#4234) * Update dice.py reduce redundant operations in DiceFocalLoss, initially caused oom Signed-off-by: Ryan Clanton <55164720+ryancinsight@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Ryan Clanton <55164720+ryancinsight@users.noreply.github.com> * [MONAI] python code formatting Signed-off-by: monai-bot Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: monai-bot Signed-off-by: kbressem --- monai/losses/dice.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 610327ef63..67802abc0a 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -796,8 +796,6 @@ def __init__( """ super().__init__() self.dice = DiceLoss( - include_background=include_background, - to_onehot_y=to_onehot_y, sigmoid=sigmoid, softmax=softmax, other_act=other_act, @@ -808,19 +806,15 @@ def __init__( smooth_dr=smooth_dr, batch=batch, ) - self.focal = FocalLoss( - include_background=include_background, - to_onehot_y=to_onehot_y, - gamma=gamma, - weight=focal_weight, - reduction=reduction, - ) + self.focal = 
FocalLoss(gamma=gamma, weight=focal_weight, reduction=reduction)
         if lambda_dice < 0.0:
             raise ValueError("lambda_dice should be no less than 0.0.")
         if lambda_focal < 0.0:
             raise ValueError("lambda_focal should be no less than 0.0.")
         self.lambda_dice = lambda_dice
         self.lambda_focal = lambda_focal
+        self.to_onehot_y = to_onehot_y
+        self.include_background = include_background

     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -837,6 +831,22 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if len(input.shape) != len(target.shape):
             raise ValueError("the number of dimensions for input and target should be the same.")

+        n_pred_ch = input.shape[1]
+
+        if self.to_onehot_y:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+            else:
+                target = one_hot(target, num_classes=n_pred_ch)
+
+        if not self.include_background:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `include_background=False` ignored.")
+            else:
+                # if skipping background, remove the first channel
+                target = target[:, 1:]
+                input = input[:, 1:]
+
         dice_loss = self.dice(input, target)
         focal_loss = self.focal(input, target)
         total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss

From 84c9c1590e9bbaa7510d92c4e8f47dd299256e7b Mon Sep 17 00:00:00 2001
From: Behrooz Hashemian <3968947+drbeh@users.noreply.github.com>
Date: Fri, 6 May 2022 15:28:08 -0400
Subject: [PATCH 05/28] Bug fix and improvement in WSI (#4216)

* Make all transforms optional
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Update wsireader tests
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Remove optional from PersistentDataset and its derivatives
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Add unittests for cache without transform
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Add default replace_rate
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Add default value
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Set default replace_rate to 0.1
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Update metadata to include path
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Adds SmartCachePatchWSIDataset
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Add unittests for SmartCachePatchWSIDataset
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Update references
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Update docs
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Remove smart cache
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Remove unused imports
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Add path metadata for OpenSlide
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Update metadata to be unified across different backends
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Update wsi metadata for multi wsi objects
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Add unittests for wsi metadata
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

Signed-off-by: kbressem
---
 monai/data/wsi_datasets.py  |  16 ++---
 monai/data/wsi_reader.py    | 118 +++++++++++++++++-------------------
 tests/test_wsireader_new.py |  29 ++++++---
 3 files changed, 85 insertions(+), 78 deletions(-)
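For illustration, a minimal sketch of the resulting call pattern (shapes and argument values are placeholders, not from the patch; the one-hot and background handling now runs once in `forward` and is shared by both terms):

    import torch
    from monai.losses import DiceFocalLoss

    pred = torch.randn(1, 3, 8, 8)                      # raw logits for 3 classes
    target = torch.randint(0, 3, (1, 1, 8, 8)).float()  # label-index map

    loss_fn = DiceFocalLoss(to_onehot_y=True, softmax=True, include_background=False)
    loss = loss_fn(pred, target)

diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py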
index a895e8aa45..750b3fda20 100644
--- a/monai/data/wsi_datasets.py
+++ b/monai/data/wsi_datasets.py
@@ -10,7 +10,7 @@
 # limitations under the License.

 import inspect
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, Optional, Sequence, Tuple, Union

 import numpy as np

@@ -32,10 +32,12 @@ class PatchWSIDataset(Dataset):
         size: the size of patch to be extracted from the whole slide image.
         level: the level at which the patches to be extracted (default to 0).
         transform: transforms to be executed on input data.
-        reader: the module to be used for loading whole slide imaging,
-            - if `reader` is a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM.
-            - if `reader` is a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.
-            - if `reader` is an instance of a a class inherited from `BaseWSIReader`, it is set as the wsi_reader.
+        reader: the module to be used for loading whole slide imaging. If `reader` is
+
+            - a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM.
+            - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.
+            - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.
+
         kwargs: additional arguments to pass to `WSIReader` or the provided whole slide reader class

     Note:
@@ -45,14 +47,14 @@ class PatchWSIDataset(Dataset):

         [
             {"image": "path/to/image1.tiff", "location": [200, 500], "label": 0},
-            {"image": "path/to/image2.tiff", "location": [100, 700], "label": 1}
+            {"image": "path/to/image2.tiff", "location": [100, 700], "size": [20, 20], "level": 2, "label": 1}
         ]

     """

     def __init__(
         self,
-        data: List,
+        data: Sequence,
         size: Optional[Union[int, Tuple[int, int]]] = None,
         level: Optional[int] = None,
         transform: Optional[Callable] = None,
diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py
index 02032a0ae6..8dee1f453e 100644
--- a/monai/data/wsi_reader.py
+++ b/monai/data/wsi_reader.py
@@ -10,6 +10,7 @@
 # limitations under the License.

 from abc import abstractmethod
+from os.path import abspath
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import numpy as np

@@ -53,6 +54,7 @@ class BaseWSIReader(ImageReader):
     """

     supported_suffixes: List[str] = []
+    backend = ""

     def __init__(self, level: int, **kwargs):
         super().__init__()
@@ -63,7 +65,7 @@ def __init__(self, level: int, **kwargs):
     @abstractmethod
     def get_size(self, wsi, level: int) -> Tuple[int, int]:
         """
-        Returns the size of the whole slide image at a given level.
+        Returns the size (height, width) of the whole slide image at a given level.
Args: wsi: a whole slide image object loaded from a file @@ -83,6 +85,11 @@ def get_level_count(self, wsi) -> int: """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + @abstractmethod + def get_file_path(self, wsi) -> str: + """Return the file path for the WSI object""" + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + @abstractmethod def get_patch( self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str @@ -102,12 +109,14 @@ def get_patch( """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - @abstractmethod - def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: + def get_metadata( + self, wsi, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int + ) -> Dict: """ Returns metadata of the extracted patch from the whole slide image. Args: + wsi: the whole slide image object, from which the patch is loaded patch: extracted patch from whole slide image location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). size: (height, width) tuple giving the patch size at the given level (`level`). @@ -115,7 +124,14 @@ def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple level: the level number. Defaults to 0 """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + metadata: Dict = { + "backend": self.backend, + "original_channel_dim": 0, + "spatial_shape": np.asarray(patch.shape[1:]), + "wsi": {"path": self.get_file_path(wsi)}, + "patch": {"location": location, "size": size, "level": level}, + } + return metadata def get_data( self, @@ -194,8 +210,26 @@ def get_data( patch_list.append(patch) # Set patch-related metadata - each_meta = self.get_metadata(patch=patch, location=location, size=size, level=level) - metadata.update(each_meta) + each_meta = self.get_metadata(wsi=each_wsi, patch=patch, location=location, size=size, level=level) + + if len(wsi) == 1: + metadata = each_meta + else: + if not metadata: + metadata = { + "backend": each_meta["backend"], + "original_channel_dim": each_meta["original_channel_dim"], + "spatial_shape": each_meta["spatial_shape"], + "wsi": [each_meta["wsi"]], + "patch": [each_meta["patch"]], + } + else: + if metadata["original_channel_dim"] != each_meta["original_channel_dim"]: + raise ValueError("original_channel_dim is not consistent across wsi objects.") + if any(metadata["spatial_shape"] != each_meta["spatial_shape"]): + raise ValueError("spatial_shape is not consistent across wsi objects.") + metadata["wsi"].append(each_meta["wsi"]) + metadata["patch"].append(each_meta["patch"]) return _stack_images(patch_list, metadata), metadata @@ -247,7 +281,7 @@ def get_level_count(self, wsi) -> int: def get_size(self, wsi, level: int) -> Tuple[int, int]: """ - Returns the size of the whole slide image at a given level. + Returns the size (height, width) of the whole slide image at a given level. Args: wsi: a whole slide image object loaded from a file @@ -256,19 +290,9 @@ def get_size(self, wsi, level: int) -> Tuple[int, int]: """ return self.reader.get_size(wsi, level) - def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: - """ - Returns metadata of the extracted patch from the whole slide image. 
- - Args: - patch: extracted patch from whole slide image - location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). - size: (height, width) tuple giving the patch size at the given level (`level`). - If None, it is set to the full image size at the given level. - level: the level number. Defaults to 0 - - """ - return self.reader.get_metadata(patch=patch, size=size, location=location, level=level) + def get_file_path(self, wsi) -> str: + """Return the file path for the WSI object""" + return self.reader.get_file_path(wsi) def get_patch( self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str @@ -317,6 +341,7 @@ class CuCIMWSIReader(BaseWSIReader): """ supported_suffixes = ["tif", "tiff", "svs"] + backend = "cucim" def __init__(self, level: int = 0, **kwargs): super().__init__(level, **kwargs) @@ -335,7 +360,7 @@ def get_level_count(wsi) -> int: @staticmethod def get_size(wsi, level: int) -> Tuple[int, int]: """ - Returns the size of the whole slide image at a given level. + Returns the size (height, width) of the whole slide image at a given level. Args: wsi: a whole slide image object loaded from a file @@ -344,27 +369,9 @@ def get_size(wsi, level: int) -> Tuple[int, int]: """ return (wsi.resolutions["level_dimensions"][level][1], wsi.resolutions["level_dimensions"][level][0]) - def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: - """ - Returns metadata of the extracted patch from the whole slide image. - - Args: - patch: extracted patch from whole slide image - location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). - size: (height, width) tuple giving the patch size at the given level (`level`). - If None, it is set to the full image size at the given level. - level: the level number. Defaults to 0 - - """ - metadata: Dict = { - "backend": "cucim", - "spatial_shape": np.asarray(patch.shape[1:]), - "original_channel_dim": 0, - "location": location, - "size": size, - "level": level, - } - return metadata + def get_file_path(self, wsi) -> str: + """Return the file path for the WSI object""" + return str(abspath(wsi.path)) def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): """ @@ -440,6 +447,7 @@ class OpenSlideWSIReader(BaseWSIReader): """ supported_suffixes = ["tif", "tiff", "svs"] + backend = "openslide" def __init__(self, level: int = 0, **kwargs): super().__init__(level, **kwargs) @@ -458,7 +466,7 @@ def get_level_count(wsi) -> int: @staticmethod def get_size(wsi, level: int) -> Tuple[int, int]: """ - Returns the size of the whole slide image at a given level. + Returns the size (height, width) of the whole slide image at a given level. Args: wsi: a whole slide image object loaded from a file @@ -467,27 +475,9 @@ def get_size(wsi, level: int) -> Tuple[int, int]: """ return (wsi.level_dimensions[level][1], wsi.level_dimensions[level][0]) - def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: - """ - Returns metadata of the extracted patch from the whole slide image. - - Args: - patch: extracted patch from whole slide image - location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). - size: (height, width) tuple giving the patch size at the given level (`level`). - If None, it is set to the full image size at the given level. 
- level: the level number. Defaults to 0 - - """ - metadata: Dict = { - "backend": "openslide", - "spatial_shape": np.asarray(patch.shape[1:]), - "original_channel_dim": 0, - "location": location, - "size": size, - "level": level, - } - return metadata + def get_file_path(self, wsi) -> str: + """Return the file path for the WSI object""" + return str(abspath(wsi._filename)) def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): """ diff --git a/tests/test_wsireader_new.py b/tests/test_wsireader_new.py index 2ac4125f97..4faec53978 100644 --- a/tests/test_wsireader_new.py +++ b/tests/test_wsireader_new.py @@ -125,8 +125,13 @@ class Tests(unittest.TestCase): def test_read_whole_image(self, file_path, level, expected_shape): reader = WSIReader(self.backend, level=level) with reader.read(file_path) as img_obj: - img = reader.get_data(img_obj)[0] + img, meta = reader.get_data(img_obj) self.assertTupleEqual(img.shape, expected_shape) + self.assertEqual(meta["backend"], self.backend) + self.assertEqual(meta["wsi"]["path"], str(os.path.abspath(file_path))) + self.assertEqual(meta["patch"]["level"], level) + self.assertTupleEqual(meta["patch"]["size"], expected_shape[1:]) + self.assertTupleEqual(meta["patch"]["location"], (0, 0)) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_read_region(self, file_path, patch_info, expected_img): @@ -138,29 +143,39 @@ def test_read_region(self, file_path, patch_info, expected_img): reader.get_data(img_obj, **patch_info)[0] else: # Read twice to check multiple calls - img = reader.get_data(img_obj, **patch_info)[0] + img, meta = reader.get_data(img_obj, **patch_info) img2 = reader.get_data(img_obj, **patch_info)[0] self.assertTupleEqual(img.shape, img2.shape) self.assertIsNone(assert_array_equal(img, img2)) self.assertTupleEqual(img.shape, expected_img.shape) self.assertIsNone(assert_array_equal(img, expected_img)) + self.assertEqual(meta["backend"], self.backend) + self.assertEqual(meta["wsi"]["path"], str(os.path.abspath(file_path))) + self.assertEqual(meta["patch"]["level"], patch_info["level"]) + self.assertTupleEqual(meta["patch"]["size"], expected_img.shape[1:]) + self.assertTupleEqual(meta["patch"]["location"], patch_info["location"]) @parameterized.expand([TEST_CASE_3]) - def test_read_region_multi_wsi(self, file_path, patch_info, expected_img): + def test_read_region_multi_wsi(self, file_path_list, patch_info, expected_img): kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {} reader = WSIReader(self.backend, **kwargs) - img_obj = reader.read(file_path, **kwargs) + img_obj_list = reader.read(file_path_list, **kwargs) if self.backend == "tifffile": with self.assertRaises(ValueError): - reader.get_data(img_obj, **patch_info)[0] + reader.get_data(img_obj_list, **patch_info)[0] else: # Read twice to check multiple calls - img = reader.get_data(img_obj, **patch_info)[0] - img2 = reader.get_data(img_obj, **patch_info)[0] + img, meta = reader.get_data(img_obj_list, **patch_info) + img2 = reader.get_data(img_obj_list, **patch_info)[0] self.assertTupleEqual(img.shape, img2.shape) self.assertIsNone(assert_array_equal(img, img2)) self.assertTupleEqual(img.shape, expected_img.shape) self.assertIsNone(assert_array_equal(img, expected_img)) + self.assertEqual(meta["backend"], self.backend) + self.assertEqual(meta["wsi"][0]["path"], str(os.path.abspath(file_path_list[0]))) + self.assertEqual(meta["patch"][0]["level"], patch_info["level"]) + self.assertTupleEqual(meta["patch"][0]["size"], 
expected_img.shape[1:]) + self.assertTupleEqual(meta["patch"][0]["location"], patch_info["location"]) @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1]) @skipUnless(has_tiff, "Requires tifffile.") From 63ce977cb578c4daee1f02b8a8b9825b0c20831d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 9 May 2022 15:57:21 +0100 Subject: [PATCH 06/28] Replace module (#4245) * replace modules Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> * fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> * replace_module -> replace_modules Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> * fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> Signed-off-by: kbressem --- monai/networks/__init__.py | 2 + monai/networks/utils.py | 104 ++++++++++++++++++++++++++++++++++- tests/test_replace_module.py | 97 ++++++++++++++++++++++++++++++++ 3 files changed, 202 insertions(+), 1 deletion(-) create mode 100644 tests/test_replace_module.py diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 76223dfaef..0543b11632 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -20,6 +20,8 @@ one_hot, pixelshuffle, predict_segmentation, + replace_modules, + replace_modules_temp, save_state, slice_channels, to_norm_affine, diff --git a/monai/networks/utils.py b/monai/networks/utils.py index f22be31524..34ea4f716e 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -15,7 +15,8 @@ import warnings from collections import OrderedDict from contextlib import contextmanager -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union +from copy import deepcopy +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -41,6 +42,8 @@ "save_state", "convert_to_torchscript", "meshgrid_ij", + "replace_modules", + "replace_modules_temp", ] @@ -551,3 +554,102 @@ def meshgrid_ij(*tensors): if pytorch_after(1, 10): return torch.meshgrid(*tensors, indexing="ij") return torch.meshgrid(*tensors) + + +def _replace_modules( + parent: torch.nn.Module, + name: str, + new_module: torch.nn.Module, + out: List[Tuple[str, torch.nn.Module]], + strict_match: bool = True, + match_device: bool = True, +) -> None: + """ + Helper function for :py:class:`monai.networks.utils.replace_modules`. + """ + if match_device: + devices = list({i.device for i in parent.parameters()}) + # if only one device for whole of model + if len(devices) == 1: + new_module.to(devices[0]) + idx = name.find(".") + # if there is "." in name, call recursively + if idx != -1: + parent_name = name[:idx] + parent = getattr(parent, parent_name) + name = name[idx + 1 :] + _out: List[Tuple[str, torch.nn.Module]] = [] + _replace_modules(parent, name, new_module, _out) + # prepend the parent name + out += [(f"{parent_name}.{r[0]}", r[1]) for r in _out] + # no "." 
in module name, do the actual replacing
+    else:
+        if strict_match:
+            old_module = getattr(parent, name)
+            setattr(parent, name, new_module)
+            out += [(name, old_module)]
+        else:
+            for mod_name, _ in parent.named_modules():
+                if name in mod_name:
+                    _replace_modules(parent, mod_name, deepcopy(new_module), out, strict_match=True)
+
+
+def replace_modules(
+    parent: torch.nn.Module,
+    name: str,
+    new_module: torch.nn.Module,
+    strict_match: bool = True,
+    match_device: bool = True,
+) -> List[Tuple[str, torch.nn.Module]]:
+    """
+    Replace sub-module(s) in a parent module.
+
+    The name of the module to be replaced can be nested e.g.,
+    `features.denseblock1.denselayer1.layers.relu1`. If this is the case (there are "."
+    in the module name), then this function will recursively call itself.
+
+    Args:
+        parent: module that contains the module to be replaced
+        name: name of module to be replaced. Can include ".".
+        new_module: `torch.nn.Module` to be placed at position `name` inside `parent`. This will
+            be deep copied if `strict_match == False` so that multiple instances are independent.
+        strict_match: if `True`, module name must `== name`. If `False` then
+            `name in named_modules()` will be used. `True` can be used to change just
+            one module, whereas `False` can be used to replace all modules with similar
+            name (e.g., `relu`).
+        match_device: if `True`, the device of the new module will match the model. Requires all
+            of `parent` to be on the same device.
+
+    Returns:
+        List of tuples of replaced modules. Element 0 is module name, element 1 is the replaced module.
+
+    Raises:
+        AttributeError: if `strict_match` is `True` and `name` is not a named module in `parent`.
+    """
+    out: List[Tuple[str, torch.nn.Module]] = []
+    _replace_modules(parent, name, new_module, out, strict_match, match_device)
+    return out
+
+
+@contextmanager
+def replace_modules_temp(
+    parent: torch.nn.Module,
+    name: str,
+    new_module: torch.nn.Module,
+    strict_match: bool = True,
+    match_device: bool = True,
+):
+    """
+    Temporarily replace sub-module(s) in a parent module (context manager).
+
+    See :py:class:`monai.networks.utils.replace_modules`.
+    """
+    replaced: List[Tuple[str, torch.nn.Module]] = []
+    try:
+        # replace
+        _replace_modules(parent, name, new_module, replaced, strict_match, match_device)
+        yield
+    finally:
+        # revert
+        for name, module in replaced:
+            _replace_modules(parent, name, module, [], strict_match=True, match_device=match_device)
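For illustration, a minimal usage sketch of the two helpers (the nested module name follows the DenseNet example in the docstring; the Softmax replacement is purely illustrative):

    import torch
    from monai.networks.nets import DenseNet121
    from monai.networks.utils import replace_modules, replace_modules_temp

    net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)

    # temporarily swap one nested module; the original is restored on exit
    with replace_modules_temp(net, "features.denseblock1.denselayer1.layers.relu1", torch.nn.Softmax(dim=1)):
        out = net(torch.randn(1, 1, 64, 64))

    # permanently swap every module whose name contains "relu"
    replaced = replace_modules(net, "relu", torch.nn.Softmax(dim=1), strict_match=False)

diff --git a/tests/test_replace_module.py b/tests/test_replace_module.py
new file mode 100644
index 0000000000..4cb4443410
--- /dev/null
+++ b/tests/test_replace_module.py
@@ -0,0 +1,97 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.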
+
+import unittest
+from typing import Optional, Type
+
+import torch
+from parameterized import parameterized
+
+from monai.networks.nets import DenseNet121
+from monai.networks.utils import replace_modules, replace_modules_temp
+from tests.utils import TEST_DEVICES
+
+TESTS = []
+for device in TEST_DEVICES:
+    for match_device in (True, False):
+        # replace 1
+        TESTS.append(("features.denseblock1.denselayer1.layers.relu1", True, match_device, *device))
+        # replace 1 (but not strict)
+        TESTS.append(("features.denseblock1.denselayer1.layers.relu1", False, match_device, *device))
+        # replace multiple
+        TESTS.append(("relu", False, match_device, *device))
+
+
+class TestReplaceModule(unittest.TestCase):
+    def setUp(self):
+        self.net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
+        self.num_relus = self.get_num_modules(torch.nn.ReLU)
+        self.total = self.get_num_modules()
+        self.assertGreater(self.num_relus, 0)
+
+    def get_num_modules(self, mod: Optional[Type[torch.nn.Module]] = None) -> int:
+        m = [m for _, m in self.net.named_modules()]
+        if mod is not None:
+            m = [_m for _m in m if isinstance(_m, mod)]
+        return len(m)
+
+    def check_replaced_modules(self, name, match_device):
+        # total num modules should remain the same
+        self.assertEqual(self.total, self.get_num_modules())
+        num_relus_mod = self.get_num_modules(torch.nn.ReLU)
+        num_softmax = self.get_num_modules(torch.nn.Softmax)
+        # list of returned modules should be as long as number of softmax
+        self.assertEqual(self.num_relus, num_relus_mod + num_softmax)
+        if name == "relu":
+            # at least 2 softmaxes
+            self.assertGreaterEqual(num_softmax, 2)
+        else:
+            # one softmax
+            self.assertEqual(num_softmax, 1)
+        if match_device:
+            self.assertEqual(len(list({i.device for i in self.net.parameters()})), 1)
+
+    @parameterized.expand(TESTS)
+    def test_replace(self, name, strict_match, match_device, device):
+        self.net.to(device)
+        # replace module(s)
+        replaced = replace_modules(self.net, name, torch.nn.Softmax(), strict_match, match_device)
+        self.check_replaced_modules(name, match_device)
+        # number of returned modules should equal number of softmax modules
+        self.assertEqual(len(replaced), self.get_num_modules(torch.nn.Softmax))
+        # all replaced modules should be ReLU
+        for r in replaced:
+            self.assertIsInstance(r[1], torch.nn.ReLU)
+        # if a specific module was named, check that the name matches exactly
+        if name == "features.denseblock1.denselayer1.layers.relu1":
+            self.assertEqual(replaced[0][0], name)
+
+    @parameterized.expand(TESTS)
+    def test_replace_context_manager(self, name, strict_match, match_device, device):
+        self.net.to(device)
+        with replace_modules_temp(self.net, name, torch.nn.Softmax(), strict_match, match_device):
+            self.check_replaced_modules(name, match_device)
+        # Check that model was correctly reverted
+        self.assertEqual(self.get_num_modules(), self.total)
+        self.assertEqual(self.get_num_modules(torch.nn.ReLU), self.num_relus)
+        self.assertEqual(self.get_num_modules(torch.nn.Softmax), 0)
+
+    def test_raises(self):
+        # name doesn't exist in module
+        with self.assertRaises(AttributeError):
+            replace_modules(self.net, "non_existent_module", torch.nn.Softmax(), strict_match=True)
+        with self.assertRaises(AttributeError):
+            with replace_modules_temp(self.net, "non_existent_module", torch.nn.Softmax(), strict_match=True):
+                pass
+
+
+if __name__ == "__main__":
+    unittest.main()

From a0b410048eef7e02cb457c92e68d1c3e4a0e3d9a Mon Sep 17 00:00:00 2001
From: Can Zhao <69829124+Can-Zhao@users.noreply.github.com>
Date: Tue, 10
May 2022 12:22:22 -0400 Subject: [PATCH 07/28] Add GaussianSmooth as antialiasing filter in Resize (#4249) Signed-off-by: Can Zhao Signed-off-by: kbressem --- monai/transforms/spatial/array.py | 42 +++++++++++++++++++++++++++++++ tests/test_resize.py | 33 +++++++++++++++++++----- 2 files changed, 69 insertions(+), 6 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 6b67762b95..65df5d2b1b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -26,6 +26,7 @@ from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.networks.utils import meshgrid_ij, normalize_transform from monai.transforms.croppad.array import CenterSpatialCrop, Pad +from monai.transforms.intensity.array import GaussianSmooth from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform from monai.transforms.utils import ( create_control_grid, @@ -622,6 +623,15 @@ class Resize(Transform): align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + anti_aliasing: bool + Whether to apply a Gaussian filter to smooth the image prior + to downsampling. It is crucial to filter when downsampling + the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` + anti_aliasing_sigma: {float, tuple of floats}, optional + Standard deviation for Gaussian filtering used when anti-aliasing. + By default, this value is chosen as (s - 1) / 2 where s is the + downsampling factor, where s > 1. For the up-size case, s < 1, no + anti-aliasing is performed prior to rescaling. """ backend = [TransformBackends.TORCH] @@ -632,17 +642,23 @@ def __init__( size_mode: str = "all", mode: Union[InterpolateMode, str] = InterpolateMode.AREA, align_corners: Optional[bool] = None, + anti_aliasing: bool = False, + anti_aliasing_sigma: Union[Sequence[float], float, None] = None, ) -> None: self.size_mode = look_up_option(size_mode, ["all", "longest"]) self.spatial_size = spatial_size self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) self.align_corners = align_corners + self.anti_aliasing = anti_aliasing + self.anti_aliasing_sigma = anti_aliasing_sigma def __call__( self, img: NdarrayOrTensor, mode: Optional[Union[InterpolateMode, str]] = None, align_corners: Optional[bool] = None, + anti_aliasing: Optional[bool] = None, + anti_aliasing_sigma: Union[Sequence[float], float, None] = None, ) -> NdarrayOrTensor: """ Args: @@ -653,11 +669,23 @@ def __call__( align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + anti_aliasing: bool, optional + Whether to apply a Gaussian filter to smooth the image prior + to downsampling. It is crucial to filter when downsampling + the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` + anti_aliasing_sigma: {float, tuple of floats}, optional + Standard deviation for Gaussian filtering used when anti-aliasing. + By default, this value is chosen as (s - 1) / 2 where s is the + downsampling factor, where s > 1. For the up-size case, s < 1, no + anti-aliasing is performed prior to rescaling. Raises: ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. 
""" + anti_aliasing = self.anti_aliasing if anti_aliasing is None else anti_aliasing + anti_aliasing_sigma = self.anti_aliasing_sigma if anti_aliasing_sigma is None else anti_aliasing_sigma + img_, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) if self.size_mode == "all": input_ndim = img_.ndim - 1 # spatial ndim @@ -677,6 +705,20 @@ def __call__( raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") scale = self.spatial_size / max(img_size) spatial_size_ = tuple(int(round(s * scale)) for s in img_size) + + if anti_aliasing and any(x < y for x, y in zip(spatial_size_, img_.shape[1:])): + factors = torch.div(torch.Tensor(list(img_.shape[1:])), torch.Tensor(spatial_size_)) + if anti_aliasing_sigma is None: + # if sigma is not given, use the default sigma in skimage.transform.resize + anti_aliasing_sigma = torch.maximum(torch.zeros(factors.shape), (factors - 1) / 2).tolist() + else: + # if sigma is given, use the given value for downsampling axis + anti_aliasing_sigma = list(ensure_tuple_rep(anti_aliasing_sigma, len(spatial_size_))) + for axis in range(len(spatial_size_)): + anti_aliasing_sigma[axis] = anti_aliasing_sigma[axis] * int(factors[axis] > 1) + anti_aliasing_filter = GaussianSmooth(sigma=anti_aliasing_sigma) + img_ = anti_aliasing_filter(img_) + resized = torch.nn.functional.interpolate( input=img_.unsqueeze(0), size=spatial_size_, diff --git a/tests/test_resize.py b/tests/test_resize.py index 06246b2358..cb24cf2cc3 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -13,6 +13,7 @@ import numpy as np import skimage.transform +import torch from parameterized import parameterized from monai.transforms import Resize @@ -24,6 +25,10 @@ TEST_CASE_2 = [{"spatial_size": 6, "mode": "trilinear", "align_corners": True}, (2, 4, 6)] +TEST_CASE_3 = [{"spatial_size": 15, "anti_aliasing": True}, (6, 10, 15)] + +TEST_CASE_4 = [{"spatial_size": 6, "anti_aliasing": True, "anti_aliasing_sigma": 2.0}, (2, 4, 6)] + class TestResize(NumpyImageTestCase2D): def test_invalid_inputs(self): @@ -36,10 +41,15 @@ def test_invalid_inputs(self): resize(self.imt[0]) @parameterized.expand( - [((32, -1), "area"), ((32, 32), "area"), ((32, 32, 32), "trilinear"), ((256, 256), "bilinear")] + [ + ((32, -1), "area", True), + ((32, 32), "area", False), + ((32, 32, 32), "trilinear", True), + ((256, 256), "bilinear", False), + ] ) - def test_correct_results(self, spatial_size, mode): - resize = Resize(spatial_size, mode=mode) + def test_correct_results(self, spatial_size, mode, anti_aliasing): + resize = Resize(spatial_size, mode=mode, anti_aliasing=anti_aliasing) _order = 0 if mode.endswith("linear"): _order = 1 @@ -47,7 +57,7 @@ def test_correct_results(self, spatial_size, mode): spatial_size = (32, 64) expected = [ skimage.transform.resize( - channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=False + channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=anti_aliasing ) for channel in self.imt[0] ] @@ -55,9 +65,20 @@ def test_correct_results(self, spatial_size, mode): expected = np.stack(expected).astype(np.float32) for p in TEST_NDARRAYS: out = resize(p(self.imt[0])) - assert_allclose(out, expected, type_test=False, atol=0.9) + if not anti_aliasing: + assert_allclose(out, expected, type_test=False, atol=0.9) + else: + # skimage uses reflect padding for anti-aliasing filter. + # Our implementation reuses GaussianSmooth() as anti-aliasing filter, which uses zero padding instead. 
+ # Thus their results near the image boundary will be different. + if isinstance(out, torch.Tensor): + out = out.cpu().detach().numpy() + good = np.sum(np.isclose(expected, out, atol=0.9)) + self.assertLessEqual( + np.abs(good - expected.size) / float(expected.size), 0.2, "at most 20 percent mismatch " + ) - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_longest_shape(self, input_param, expected_shape): input_data = np.random.randint(0, 2, size=[3, 4, 7, 10]) input_param["size_mode"] = "longest" From 1cdff36f3d640a47ee713dbb70187f7a33610a3f Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Wed, 11 May 2022 17:10:47 +0800 Subject: [PATCH 08/28] 4235 fix 2204 nvfuser issue (#4241) * reproduce issue Signed-off-by: Yiheng Wang * remove 22.01 02 Signed-off-by: Yiheng Wang * remove other workflows Signed-off-by: Yiheng Wang * run on pull request Signed-off-by: Yiheng Wang * remove sleep Signed-off-by: Yiheng Wang * test single layer forward Signed-off-by: Yiheng Wang * add has_nvfuser Signed-off-by: Yiheng Wang * add check within factory Signed-off-by: Yiheng Wang * revert to original cron.yml Signed-off-by: Yiheng Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix old pt issue Signed-off-by: Yiheng Wang * change to return directly if no cuda Signed-off-by: Yiheng Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: kbressem --- monai/networks/layers/factories.py | 19 +++++++++++++++++-- tests/test_dynunet.py | 11 ++++++----- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index b808c24de0..89fe1912a5 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -63,12 +63,14 @@ def use_factory(fact_args): import warnings from typing import Any, Callable, Dict, Tuple, Type, Union +import torch import torch.nn as nn from monai.utils import look_up_option, optional_import InstanceNorm3dNVFuser, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser") + __all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"] @@ -263,8 +265,21 @@ def instance_nvfuser_factory(dim): if dim != 3: warnings.warn(f"`InstanceNorm3dNVFuser` only supports 3d cases, use {types[dim - 1]} instead.") return types[dim - 1] - if not has_nvfuser: - warnings.warn("`apex.normalization.InstanceNorm3dNVFuser` is not found, use nn.InstanceNorm3d instead.") + # test InstanceNorm3dNVFuser installation with a basic example + has_nvfuser_flag = has_nvfuser + if not torch.cuda.is_available(): + return nn.InstanceNorm3d + try: + layer = InstanceNorm3dNVFuser(num_features=1, affine=True).to("cuda:0") + inp = torch.randn([1, 1, 1, 1, 1]).to("cuda:0") + out = layer(inp) + del inp, out, layer + except Exception: + has_nvfuser_flag = False + if not has_nvfuser_flag: + warnings.warn( + "`apex.normalization.InstanceNorm3dNVFuser` is not installed properly, use nn.InstanceNorm3d instead." 
+        )
         return nn.InstanceNorm3d
     return InstanceNorm3dNVFuser

diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py
index 14006b96e6..ff5d5efbef 100644
--- a/tests/test_dynunet.py
+++ b/tests/test_dynunet.py
@@ -17,11 +17,9 @@
 from monai.networks import eval_mode
 from monai.networks.nets import DynUNet
-from monai.utils import optional_import
+from monai.utils.module import pytorch_after
 from tests.utils import skip_if_no_cuda, skip_if_windows, test_script_save

-_, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")
-
 device = "cuda" if torch.cuda.is_available() else "cpu"

 strides: Sequence[Union[Sequence[int], int]]
@@ -123,7 +121,6 @@ def test_script(self):

 @skip_if_no_cuda
 @skip_if_windows
-@unittest.skipUnless(has_nvfuser, "To use `instance_nvfuser`, `apex.normalization.InstanceNorm3dNVFuser` is needed.")
 class TestDynUNetWithInstanceNorm3dNVFuser(unittest.TestCase):
     @parameterized.expand([TEST_CASE_DYNUNET_3D[0]])
     def test_consistency(self, input_param, input_shape, _):
@@ -145,7 +142,11 @@ def test_consistency(self, input_param, input_shape, _):
                         with eval_mode(net_fuser):
                             result_fuser = net_fuser(input_tensor)

-                        torch.testing.assert_close(result, result_fuser)
+                        # torch.testing.assert_allclose() is deprecated since 1.12 and will be removed in 1.14
+                        if pytorch_after(1, 12):
+                            torch.testing.assert_close(result, result_fuser)
+                        else:
+                            torch.testing.assert_allclose(result, result_fuser)

 class TestDynUNetDeepSupervision(unittest.TestCase):

From 591931f14b24fbda00d2989f7c90f98c21b0bef1 Mon Sep 17 00:00:00 2001
From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Date: Wed, 11 May 2022 14:28:52 +0100
Subject: [PATCH 09/28] Update to Bundle Specification (#4250)

* Update to bundle specification
Signed-off-by: Eric Kerfoot
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Adding description in spec discussing the saved Torchscript object's file storage behaviour, and tweaking ckpt_export to add .json extension
Signed-off-by: Eric Kerfoot
* Annotating optional bundle files
Signed-off-by: Eric Kerfoot
* Adjusted ckpt_export test
Signed-off-by: Eric Kerfoot
* Fix
Signed-off-by: Eric Kerfoot

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: kbressem
---
 docs/source/mb_specification.rst | 23 ++++++++++++++++-------
 monai/bundle/scripts.py          |  7 ++++++-
 tests/test_bundle_ckpt_export.py |  8 +++++---
 3 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/docs/source/mb_specification.rst b/docs/source/mb_specification.rst
index 1d286052a5..a88096f274 100644
--- a/docs/source/mb_specification.rst
+++ b/docs/source/mb_specification.rst
@@ -6,7 +6,7 @@ MONAI Bundle Specification
 Overview
 ========

-This is the specification for the MONAI Bundle (MB) format of portable described deep learning models. The objective of a MB is to define a packaged network or model which includes the critical information necessary to allow users and programs to understand how the model is used and for what purpose. A bundle includes the stored weights of a model as a pickled state dictionary and/or a Torchscript object. Additional JSON files are included to store metadata about the model, information for constructing training, inference, and post-processing transform sequences, plain-text description, legal information, and other data the model creator wishes to include.
+This is the specification for the MONAI Bundle (MB) format of portable, self-described deep learning models. The objective of an MB is to define a packaged network or model which includes the critical information necessary to allow users and programs to understand how the model is used and for what purpose. A bundle includes the stored weights of a single network as a pickled state dictionary plus optionally a Torchscript object and/or an ONNX object. Additional JSON files are included to store metadata about the model, information for constructing training, inference, and post-processing transform sequences, plain-text description, legal information, and other data the model creator wishes to include.

 This specification defines the directory structure a bundle must have and the necessary files it must contain. Additional files may be included and the directory packaged into a zip file or included as extra files directly in a Torchscript file.

@@ -22,26 +22,35 @@ A MONAI Bundle is defined primarily as a directory with a set of specifically na
 ┃  ┗━ metadata.json
 ┣━ models
 ┃  ┣━ model.pt
-┃  ┗━ model.ts
+┃  ┣━ *model.ts
+┃  ┗━ *model.onnx
 ┗━ docs
-   ┣━ README.md
-   ┗━ license.txt
+   ┣━ *README.md
+   ┗━ *license.txt

-These files mostly are required to be present with the given names for the directory to define a valid bundle:
+The following files are **required** to be present with the given filenames for the directory to define a valid bundle:

 * **metadata.json**: metadata information in JSON format relating to the type of model, definition of input and output tensors, versions of the model and used software, and other information described below.
 * **model.pt**: the state dictionary of a saved model, the information to instantiate the model must be found in the metadata file.
+
+The following files are optional but must have these names in the directory given above:
+
 * **model.ts**: the Torchscript saved model if the model is compatible with being saved correctly in this format.
+* **model.onnx**: the ONNX model if the model is compatible with being saved correctly in this format.
 * **README.md**: plain-language information on the model, how to use it, author information, etc. in Markdown format.
 * **license.txt**: software license attached to the model, can be left blank if no license needed.

+Other files can be included in any of the above directories. For example, `configs` can contain further configuration JSON or YAML files to define scripts for training or inference, overriding configuration values, environment definitions such as network instantiations, and so forth. One common file to include is `inference.json` which is used to define a basic inference script which uses input files with the stored network to produce prediction output files.
+
 Archive Format
 ==============

-The bundle directory and its contents can be compressed into a zip file to constitute a single file package. When unzipped into a directory this file will reproduce the above directory structure, and should itself also be named after the model it contains.
+The bundle directory and its contents can be compressed into a zip file to constitute a single file package. When unzipped into a directory this file will reproduce the above directory structure, and should itself also be named after the model it contains. For example, `ModelName.zip` would contain at least `ModelName/configs/metadata.json` and `ModelName/models/model.pt`, thus when unzipped would place files into the directory `ModelName` rather than into the current working directory.
+
+The Torchscript file format is also just a zip file with a specific structure. When creating such an archive with `save_net_with_metadata` an MB-compliant Torchscript file can be created by including the contents of `metadata.json` as the `meta_values` argument of the function, and other files included as `more_extra_files` entries. These will be stored in an `extras` directory in the zip file and can be retrieved with `load_net_with_metadata` or with any other library/tool that can read zip data. In this format the `model.*` files are obviously not needed but `README.md` and `license.txt` as well as any others provided can be added as more extra files.

-The Torchscript file format is also just a zip file with a specific structure. When creating such an archive with `save_net_with_metadata` a MB-compliant Torchscript file can be created by including the contents of `metadata.json` as the `meta_values` argument of the function, and other files included as `more_extra_files` entries. These will be stored in a `extras` directory in the zip file and can be retrieved with `load_net_with_metadata` or with any other library/tool that can read zip data. In this format the `model.*` files are obviously not needed by `README.md` and `license.txt` can be added as more extra files.
+The `bundle` submodule of MONAI contains a number of command line programs. To produce a Torchscript bundle use `ckpt_export` with a set of specified components such as the saved weights file and metadata file. Config files can be provided as JSON or YAML dictionaries defining Python constructs used by the `ConfigParser`; however, regardless of format the produced bundle Torchscript object will store the files as JSON.
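For illustration, a sketch of writing and reading an MB-compliant Torchscript file as described above (the network and file contents are placeholders; it assumes the two helpers are importable from `monai.data`):

    import json
    import torch
    from monai.data import load_net_with_metadata, save_net_with_metadata

    net = torch.jit.script(torch.nn.Linear(4, 2))  # stand-in for a real scripted model
    meta = {"version": "0.1.0"}                    # would be the metadata.json contents

    # produces ModelName.ts; extra files are stored under extras/ in the zip
    save_net_with_metadata(
        net, "ModelName", meta_values=meta,
        more_extra_files={"inference.json": json.dumps({"network_def": {}}).encode()},
    )
    net2, meta2, extras = load_net_with_metadata("ModelName.ts", more_extra_files=["inference.json"])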
metadata.json File ================== diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index e5b306a90d..3c838d55a0 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -600,10 +600,15 @@ def ckpt_export( filename = os.path.basename(i) # remove extension filename, _ = os.path.splitext(filename) + # because all files are stored as JSON their name parts without extension must be unique if filename in extra_files: - raise ValueError(f"filename '{filename}' is given multiple times in config file list.") + raise ValueError(f"Filename part '{filename}' is given multiple times in config file list.") + # the file may be JSON or YAML but will get loaded and dumped out again as JSON extra_files[filename] = json.dumps(ConfigParser.load_config_file(i)).encode() + # add .json extension to all extra files which are always encoded as JSON + extra_files = {k + ".json": v for k, v in extra_files.items()} + save_net_with_metadata( jit_obj=net, filename_prefix_or_stream=filepath_, diff --git a/tests/test_bundle_ckpt_export.py b/tests/test_bundle_ckpt_export.py index 36aa7319f0..a7cbff22f0 100644 --- a/tests/test_bundle_ckpt_export.py +++ b/tests/test_bundle_ckpt_export.py @@ -52,10 +52,12 @@ def test_export(self, key_in_ckpt): subprocess.check_call(cmd) self.assertTrue(os.path.exists(ts_file)) - _, metadata, extra_files = load_net_with_metadata(ts_file, more_extra_files=["inference", "def_args"]) + _, metadata, extra_files = load_net_with_metadata( + ts_file, more_extra_files=["inference.json", "def_args.json"] + ) self.assertTrue("schema" in metadata) - self.assertTrue("meta_file" in json.loads(extra_files["def_args"])) - self.assertTrue("network_def" in json.loads(extra_files["inference"])) + self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"])) + self.assertTrue("network_def" in json.loads(extra_files["inference.json"])) if __name__ == "__main__": From 90e2ac930364ad6fb4fec2ef450ccb79656eb311 Mon Sep 17 00:00:00 2001 From: kbressem Date: Wed, 11 May 2022 19:19:10 +0200 Subject: [PATCH 10/28] Implement NrrdReader and NrrdImage classes Signed-off-by: kbressem --- monai/data/image_reader.py | 129 ++++++++++++++++++++++++++++++++++++- 1 file changed, 127 insertions(+), 2 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 7e1db7ef7d..33fd733a59 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -27,19 +27,21 @@ import nibabel as nib from nibabel.nifti1 import Nifti1Image from PIL import Image as PILImage + import nrrd - has_itk = has_nib = has_pil = True + has_nrrd = has_itk = has_nib = has_pil = True else: itk, has_itk = optional_import("itk", allow_namespace_pkg=True) nib, has_nib = optional_import("nibabel") Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image") PILImage, has_pil = optional_import("PIL.Image") + nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) OpenSlide, _ = optional_import("openslide", name="OpenSlide") CuImage, _ = optional_import("cucim", name="CuImage") TiffFile, _ = optional_import("tifffile", name="TiffFile") -__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "WSIReader"] +__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "WSIReader", "NrrdReader"] class ImageReader(ABC): @@ -976,3 +978,126 @@ def _extract_patches( idx += 1 return flat_patch_grid + + +class NrrdImage(): + "Wrapper for image array and header" + + def __init__(self, + array: np.ndarray, + header: dict) -> 
None:
+        self.array = array
+        self.header = header
+
+
+@require_pkg(pkg_name="nrrd")
+class NrrdReader(ImageReader):
+    """
+    Load NRRD format images based on the pynrrd library.
+
+    Args:
+        channel_dim: the channel dimension of the input image, default is None.
+            This is used to set `original_channel_dim` in the metadata; `EnsureChannelFirstD` reads this field.
+            if None, `original_channel_dim` will be either `no_channel` or `0`.
+            NRRD files are usually "channel first".
+        dtype: dtype of the data array when loading image.
+        kwargs: additional args for `nrrd.read` API. more details about available args:
+            https://github.com/mhe/pynrrd/blob/master/nrrd/reader.py
+
+    """
+    def __init__(self,
+                 channel_dim: Optional[int] = None,
+                 dtype: Union[np.dtype, type, str, None] = np.float32,
+                 **kwargs):
+        self.channel_dim = channel_dim
+        self.dtype = dtype
+        self.kwargs = kwargs
+
+    def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool:
+        """
+        Verify whether the specified `filename` is supported by pynrrd reader.
+
+        Args:
+            filename: file name or a list of file names to read.
+                if a list of files, verify all the suffixes.
+
+        """
+        suffixes: Sequence[str] = ["nrrd", "seg.nrrd"]
+        return has_nrrd and is_supported_format(filename, suffixes)
+
+    def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Sequence[Any], Any]:
+        """
+        Read image data from specified file or files.
+        Note that it returns a data object or a sequence of data objects.
+
+        Args:
+            data: file name or a list of file names to read.
+            kwargs: additional args for actual `read` API of 3rd party libs.
+
+        """
+        img_: List = []
+        filenames: Sequence[PathLike] = ensure_tuple(data)
+        kwargs_ = self.kwargs.copy()
+        kwargs_.update(kwargs)
+        for name in filenames:
+            nrrd_image = NrrdImage(*nrrd.read(name, **kwargs_))
+            img_.append(nrrd_image)
+        return img_ if len(filenames) > 1 else img_[0]
+
+    def get_data(self, img: Union[NrrdImage, List[NrrdImage]]) -> Tuple[np.ndarray, Dict]:
+        """
+        Extract data array and meta data from loaded image and return them.
+        This function must return two objects, the first is a numpy array of image data,
+        the second is a dictionary of meta data.
+
+        Args:
+            img: an `NrrdImage` object loaded from an image file or a list of image objects.
+
+        """
+        img_array: List[NrrdImage] = []
+        compatible_meta: Dict = {}
+
+        for i in ensure_tuple(img):
+            data = self._get_array_data(i)
+            img_array.append(data)
+            header = dict(i.header)
+            header["original_affine"] = self._get_affine(i)
+            header["affine"] = header["original_affine"].copy()
+            header["spatial_shape"] = i.header["sizes"]
+
+            if self.channel_dim is None:  # default to "no_channel" or -1
+                header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else 0
+            else:
+                header["original_channel_dim"] = self.channel_dim
+            _copy_compatible_dict(header, compatible_meta)
+
+        return _stack_images(img_array, compatible_meta), compatible_meta
+
+    def _get_array_data(self, img: NrrdImage) -> np.ndarray:
+        """
+        Get the array data as Numpy array of `self.dtype`
+
+        Args:
+            img: A `NrrdImage` loaded from image file
+
+        """
+        return img.array.astype(self.dtype)
+
+    def _get_affine(self, img: NrrdImage) -> np.ndarray:
+        """
+        Get the affine matrix of the image, it can be used to correct
+        spacing, orientation or execute spatial transforms.
+ + Args: + img: A `NrrdImage` loaded from image file + + """ + direction = img.header["space directions"] + origin = img.header["space origin"] + sr = min(max(direction.shape[0], 1), 3) + affine: np.ndarray = np.eye(sr + 1) + affine[:sr, :sr] = direction[:sr, :sr] + affine[:sr, -1] = origin[:sr] + flip_diag = [[-1, 1], [-1, -1, 1], [-1, -1, 1, 1]][sr - 1] # nrrd to nibabel affine + affine = np.diag(flip_diag) @ affine + return affine From 9535e36b5db575e99a3dba8d676299c3ec85c0a6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 May 2022 17:26:39 +0000 Subject: [PATCH 11/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: kbressem --- monai/data/image_reader.py | 46 +++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 33fd733a59..2be077fc23 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -27,7 +27,7 @@ import nibabel as nib from nibabel.nifti1 import Nifti1Image from PIL import Image as PILImage - import nrrd + import nrrd has_nrrd = has_itk = has_nib = has_pil = True else: @@ -980,18 +980,18 @@ def _extract_patches( return flat_patch_grid -class NrrdImage(): +class NrrdImage(): "Wrapper for image array and header" - - def __init__(self, - array: np.ndarray, - header: dict) -> None: + + def __init__(self, + array: np.ndarray, + header: dict) -> None: self.array = array self.header = header - + @require_pkg(pkg_name="nrrd") -class NrrdReader(ImageReader): +class NrrdReader(ImageReader): """ Load NRRD format images based on pynrrd library. @@ -1005,10 +1005,10 @@ class NrrdReader(ImageReader): https://github.com/mhe/pynrrd/blob/master/nrrd/reader.py """ - def __init__(self, + def __init__(self, channel_dim: Optional[int] = None, - dtype: Union[np.dtype, type, str, None] = np.float32, - **kwargs): + dtype: Union[np.dtype, type, str, None] = np.float32, + **kwargs): self.channel_dim = channel_dim self.dtype = dtype self.kwargs = kwargs @@ -1024,7 +1024,7 @@ def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ suffixes: Sequencec[str] = ["nrrd", "seg.nrrd"] return has_nrrd and is_supported_format(filename, suffixes) - + def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Sequence[Any], Any]: """ Read image data from specified file or files. @@ -1034,7 +1034,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Seq data: file name or a list of file names to read. kwargs: additional args for actual `read` API of 3rd party libs. - """ + """ img_: List = [] filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() @@ -1043,7 +1043,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Seq nrrd_image = NrrdImage(*nrrd.read(name, **kwargs_)) img_.append(nrrd_image) return img_ if len(filenames) > 1 else img_[0] - + def get_data(self, img: Union[NrrdImage, List[NrrdImage]]) -> Tuple[np.ndarray, Dict]: """ Extract data array and meta data from loaded image and return them. 
@@ -1055,8 +1055,8 @@ def get_data(self, img: Union[NrrdImage, List[NrrdImage]]) -> Tuple[np.ndarray, """ img_array: List[NrrdImage] = [] - compatible_meta: Dict = {} - + compatible_meta: Dict = {} + for i in ensure_tuple(img): data = self._get_array_data(i) img_array.append(data) @@ -1064,7 +1064,7 @@ def get_data(self, img: Union[NrrdImage, List[NrrdImage]]) -> Tuple[np.ndarray, header["original_affine"] = self._get_affine(i) header["affine"] = header["original_affine"].copy() header["spatial_shape"] = i.header["sizes"] - + if self.channel_dim is None: # default to "no_channel" or -1 header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else 0 else: @@ -1072,17 +1072,17 @@ def get_data(self, img: Union[NrrdImage, List[NrrdImage]]) -> Tuple[np.ndarray, _copy_compatible_dict(header, compatible_meta) return _stack_images(img_array, compatible_meta), compatible_meta - + def _get_array_data(self, img: NrrdImage) -> np.ndarray: """ Get the array data as Numpy array of `self.dtype` - - Args: + + Args: img: A `NrrdImage` loaded from image file - + """ return img.array.astype(self.dtype) - + def _get_affine(self, img: NrrdImage) -> np.ndarray: """ Get the affine matrix of the image, it can be used to correct @@ -1090,7 +1090,7 @@ def _get_affine(self, img: NrrdImage) -> np.ndarray: Args: img: A `NrrdImage` loaded from image file - + """ direction = img.header["space directions"] origin = img.header["space origin"] From f4139f21cad3c550aaf1d400de3550b668a5ccb7 Mon Sep 17 00:00:00 2001 From: kbressem Date: Thu, 12 May 2022 12:16:09 +0200 Subject: [PATCH 12/28] run auto style fixes on image_reader.py Signed-off-by: kbressem --- monai/data/image_reader.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 2be077fc23..635ff985df 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -25,9 +25,9 @@ if TYPE_CHECKING: import itk import nibabel as nib + import nrrd from nibabel.nifti1 import Nifti1Image from PIL import Image as PILImage - import nrrd has_nrrd = has_itk = has_nib = has_pil = True else: @@ -980,12 +980,10 @@ def _extract_patches( return flat_patch_grid -class NrrdImage(): +class NrrdImage: "Wrapper for image array and header" - def __init__(self, - array: np.ndarray, - header: dict) -> None: + def __init__(self, array: np.ndarray, header: dict) -> None: self.array = array self.header = header @@ -1005,10 +1003,10 @@ class NrrdReader(ImageReader): https://github.com/mhe/pynrrd/blob/master/nrrd/reader.py """ - def __init__(self, - channel_dim: Optional[int] = None, - dtype: Union[np.dtype, type, str, None] = np.float32, - **kwargs): + + def __init__( + self, channel_dim: Optional[int] = None, dtype: Union[np.dtype, type, str, None] = np.float32, **kwargs + ): self.channel_dim = channel_dim self.dtype = dtype self.kwargs = kwargs @@ -1098,6 +1096,6 @@ def _get_affine(self, img: NrrdImage) -> np.ndarray: affine: np.ndarray = np.eye(sr + 1) affine[:sr, :sr] = direction[:sr, :sr] affine[:sr, -1] = origin[:sr] - flip_diag = [[-1, 1], [-1, -1, 1], [-1, -1, 1, 1]][sr - 1] # nrrd to nibabel affine + flip_diag = [[-1, 1], [-1, -1, 1], [-1, -1, 1, 1]][sr - 1] # nrrd to nibabel affine affine = np.diag(flip_diag) @ affine return affine From 6a7b9e836d2e108d1beecf40774e67041efdd100 Mon Sep 17 00:00:00 2001 From: kbressem Date: Thu, 12 May 2022 12:25:52 +0200 Subject: [PATCH 13/28] add NrrdReader to monai/data/__init__.py Signed-off-by: kbressem 
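
With the reader exported from the package root, a minimal smoke test looks
like this (a sketch only; `./example.nrrd` is a hypothetical local file, not
part of this change):

    from monai.data import NrrdReader

    reader = NrrdReader()
    if reader.verify_suffix("./example.nrrd"):
        img = reader.read("./example.nrrd")
        array, meta = reader.get_data(img)
        print(array.shape, meta["spatial_shape"])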
--- monai/data/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index d9af568508..e58c26944d 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -34,7 +34,7 @@ from .folder_layout import FolderLayout from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd from .image_dataset import ImageDataset -from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader +from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, NrrdReader from .image_writer import ( SUPPORTED_WRITERS, ImageWriter, From 59aed3e4bf3908b293b7fa696c57374af7d43851 Mon Sep 17 00:00:00 2001 From: kbressem Date: Thu, 12 May 2022 19:54:17 +0200 Subject: [PATCH 14/28] Change the way spatial information is handled in NrrdReader Signed-off-by: kbressem --- monai/data/image_reader.py | 65 +++++++++++++++++++++++++++++++------- 1 file changed, 53 insertions(+), 12 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 635ff985df..219639bf8b 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -995,20 +995,26 @@ class NrrdReader(ImageReader): Args: channel_dim: the channel dimension of the input image, default is None. - this is used to set original_channel_dim in the meta data, EnsureChannelFirstD reads this field. - if None, `original_channel_dim` will be either `no_channel` or `0`. + This is used to set original_channel_dim in the meta data, EnsureChannelFirstD reads this field. + If None, `original_channel_dim` will be either `no_channel` or `0`. NRRD files are usually "channel first". dtype: dtype of the data array when loading image. + index_order: Specify whether the returned data array should be in C-order (‘C’) or Fortran-order (‘F’). + Numpy is usually in C-order, but default on the NRRD header is F kwargs: additional args for `nrrd.read` API. 
more details about available args: https://github.com/mhe/pynrrd/blob/master/nrrd/reader.py """ def __init__( - self, channel_dim: Optional[int] = None, dtype: Union[np.dtype, type, str, None] = np.float32, **kwargs + self, channel_dim: Optional[int] = None, + dtype: Union[np.dtype, type, str, None] = np.float32, + index_order: str = "F", + **kwargs ): self.channel_dim = channel_dim self.dtype = dtype + self.index_order = index_order self.kwargs = kwargs def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: @@ -1038,7 +1044,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Seq kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: - nrrd_image = NrrdImage(*nrrd.read(name, **kwargs_)) + nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, *kwargs_)) img_.append(nrrd_image) return img_ if len(filenames) > 1 else img_[0] @@ -1059,10 +1065,14 @@ def get_data(self, img: Union[NrrdImage, List[NrrdImage]]) -> Tuple[np.ndarray, data = self._get_array_data(i) img_array.append(data) header = dict(i.header) + if self.index_order == "C": + header = self._convert_F_to_C_order(header) header["original_affine"] = self._get_affine(i) + header = self._switch_lps_ras(header) header["affine"] = header["original_affine"].copy() - header["spatial_shape"] = i.header["sizes"] - + header["spatial_shape"] = header["sizes"] + [header.pop(k) for k in ("sizes", "space origin", "space directions")] # rm duplicated data in header + if self.channel_dim is None: # default to "no_channel" or -1 header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else 0 else: @@ -1092,10 +1102,41 @@ def _get_affine(self, img: NrrdImage) -> np.ndarray: """ direction = img.header["space directions"] origin = img.header["space origin"] - sr = min(max(direction.shape[0], 1), 3) - affine: np.ndarray = np.eye(sr + 1) - affine[:sr, :sr] = direction[:sr, :sr] - affine[:sr, -1] = origin[:sr] - flip_diag = [[-1, 1], [-1, -1, 1], [-1, -1, 1, 1]][sr - 1] # nrrd to nibabel affine - affine = np.diag(flip_diag) @ affine + + x, y = direction.shape + affine_diam = min(x, y)+1 + affine: np.ndarray = np.eye(affine_diam) + affine[:x, :y] = direction + affine[:(affine_diam-1), -1] = origin # len origin is always affine_diam - 1 return affine + + def _switch_lps_ras(self, header: dict) -> dict: + """ + For compatibility with nibabel, switch from LPS to RAS. Adapt affine matrix and + `space` argument in header accordingly. + + Args: + header: The image meta data as dict + + """ + if header["space"] == "left-posterior-superior": + header["space"] = "right-anterior-superior" + header["original_affine"] = orientation_ras_lps(header["original_affine"]) + return header + + def _convert_F_to_C_order(self, header: dict) -> dict: + """ + All header fields of a NRRD are specified in `F` (Fortran) order, even if the image was read as C-ordered array. 
+ 1D arrays of header['space origin'] and header['sizes'] become inverted, e.g, [1,2,3] -> [3,2,1] + The 2D Array for header['space directions'] is transposed: [[1,0,0],[0,2,0],[0,0,3]] -> [[3,0,0],[0,2,0],[0,0,1]] + For more details refer to: https://pynrrd.readthedocs.io/en/latest/user-guide.html#index-ordering + + Args: + header: The image meta data as dict + + """ + + header["space directions"] = np.rot90(np.flip(header["space directions"] , 0)) + header["space origin"] = header["space origin"][::-1] + header["sizes"] = header["sizes"][::-1] + return header \ No newline at end of file From 9c597842a553f659f1fecdeef7a174c4798353eb Mon Sep 17 00:00:00 2001 From: kbressem Date: Thu, 12 May 2022 19:55:22 +0200 Subject: [PATCH 15/28] add tests for NrrdReader Signed-off-by: kbressem --- tests/test_init_reader.py | 8 ++- tests/test_nrrd_reader.py | 117 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 2 deletions(-) create mode 100644 tests/test_nrrd_reader.py diff --git a/tests/test_init_reader.py b/tests/test_init_reader.py index 03a63cc375..0c020c727e 100644 --- a/tests/test_init_reader.py +++ b/tests/test_init_reader.py @@ -11,7 +11,7 @@ import unittest -from monai.data import ITKReader, NibabelReader, NumpyReader, PILReader +from monai.data import ITKReader, NibabelReader, NumpyReader, PILReader, NrrdReader from monai.transforms import LoadImage, LoadImaged from tests.utils import SkipIfNoModule @@ -23,13 +23,14 @@ def test_load_image(self): self.assertIsInstance(instance1, LoadImage) self.assertIsInstance(instance2, LoadImage) - for r in ["NibabelReader", "PILReader", "ITKReader", "NumpyReader", None]: + for r in ["NibabelReader", "PILReader", "ITKReader", "NumpyReader", "NrrdReader", None]: inst = LoadImaged("image", reader=r) self.assertIsInstance(inst, LoadImaged) @SkipIfNoModule("itk") @SkipIfNoModule("nibabel") @SkipIfNoModule("PIL") + @SkipIfNoModule("nrrd") def test_readers(self): inst = ITKReader() self.assertIsInstance(inst, ITKReader) @@ -46,6 +47,9 @@ def test_readers(self): inst = PILReader() self.assertIsInstance(inst, PILReader) + + inst = NrrdReader() + self.assertIsInstance(inst, NrrdReader) if __name__ == "__main__": diff --git a/tests/test_nrrd_reader.py b/tests/test_nrrd_reader.py new file mode 100644 index 0000000000..50e38650f4 --- /dev/null +++ b/tests/test_nrrd_reader.py @@ -0,0 +1,117 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
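+
+# The cases below round-trip small random arrays through nrrd.write and
+# NrrdReader: integer and float dtypes, shapes from 1-D to 6-D, a list of
+# several files, and a reference header that exercises the spatial-shape
+# handling in get_data.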
+ +import os +import tempfile +import unittest + +import numpy as np +from parameterized import parameterized +import nrrd + +from monai.data import NrrdReader + +TEST_CASE_1 = [(4, 4), "test_image.nrrd", (4, 4), np.uint8] +TEST_CASE_2 = [(4, 4, 4), "test_image.nrrd", (4, 4, 4), np.uint16] +TEST_CASE_3 = [(4, 4, 4, 4), "test_image.nrrd", (4, 4, 4, 4), np.uint32] +TEST_CASE_4 = [(1, 2, 3, 4, 5), "test_image.nrrd", (1, 2, 3, 4, 5), np.uint64] +TEST_CASE_5 = [(6, 5, 4, 3, 2, 1), "test_image.nrrd", (6, 5, 4, 3, 2, 1), np.float32] +TEST_CASE_6 = [(4,), "test_image.nrrd", (4,), np.float64] +TEST_CASE_7 = [(4,4), ["test_image.nrrd", "test_image2.nrrd", "test_image3.nrrd"], (4,4), np.float32] +TEST_CASE_8 = [(3,4,4,1), "test_image.nrrd", (3,4,4,1), np.float32, + { + "dimension": 4, + "space": "left-posterior-superior", + "sizes": [3,4,4,1], + "space directions": [[0.,0.,0.], [1.,0.,0.],[0.,1.,0.],[0.,0.,1.]], + "space origin": [0.,0.,0.] + }] + +class TestNrrdReader(unittest.TestCase): + + def test_verify_suffix(self): + reader = NrrdReader() + self.assertFalse(reader.verify_suffix("test_image.nrd")) + reader.verify_suffix("test_image.nrrd") + reader.verify_suffix("test_image.seg.nrrd") + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + def test_read_int(self, data_shape, filename, expected_shape, dtype): + min_val, max_val = np.iinfo(dtype).min, np.iinfo(dtype).max + test_image = np.random.randint(min_val, max_val, size=data_shape, dtype=dtype) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, filename) + nrrd.write(filename, test_image.astype(dtype)) + reader = NrrdReader() + result = reader.read(filename) + self.assertEqual(result.array.dtype, dtype) + self.assertTupleEqual(result.array.shape, expected_shape) + self.assertTupleEqual(tuple(result.header["sizes"]), expected_shape) + np.testing.assert_allclose(result.array, test_image) + + @parameterized.expand([TEST_CASE_5, TEST_CASE_6]) + def test_read_float(self, data_shape, filename, expected_shape, dtype): + test_image = np.random.rand(*data_shape).astype(dtype) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, filename) + nrrd.write(filename, test_image.astype(dtype)) + reader = NrrdReader() + result = reader.read(filename) + self.assertEqual(result.array.dtype, dtype) + self.assertTupleEqual(result.array.shape, expected_shape) + self.assertTupleEqual(tuple(result.header["sizes"]), expected_shape) + np.testing.assert_allclose(result.array, test_image) + + @parameterized.expand([TEST_CASE_7]) + def test_read_list(self, data_shape, filenames, expected_shape, dtype): + test_image = np.random.rand(*data_shape).astype(dtype) + with tempfile.TemporaryDirectory() as tempdir: + for i, filename in enumerate(filenames): + filenames[i] = os.path.join(tempdir, filename) + nrrd.write(filenames[i], test_image.astype(dtype)) + reader = NrrdReader() + results = reader.read(filenames) + for result in results: + self.assertTupleEqual(result.array.shape, expected_shape) + self.assertTupleEqual(tuple(result.header["sizes"]), expected_shape) + np.testing.assert_allclose(result.array, test_image) + + @parameterized.expand([TEST_CASE_8]) + def test_read_with_header(self, data_shape, filename, expected_shape, dtype, reference_header): + test_image = np.random.rand(*data_shape).astype(dtype) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, filename) + nrrd.write(filename, test_image.astype(dtype), header=reference_header) + reader = 
NrrdReader() + image_array, image_header = reader.get_data(reader.read(filename)) + self.assertIsInstance(image_array, np.ndarray) + self.assertEqual(image_array.dtype, dtype) + self.assertTupleEqual(image_array.shape, expected_shape) + np.testing.assert_allclose(image_array, test_image) + self.assertIsInstance(image_header, dict) + self.assertTupleEqual(tuple(image_header["spatial_shape"]), expected_shape) + + @parameterized.expand([TEST_CASE_8]) + def test_read_with_header_index_order_c(self, data_shape, filename, expected_shape, dtype, reference_header): + test_image = np.random.rand(*data_shape).astype(dtype) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, filename) + nrrd.write(filename, test_image.astype(dtype), header=reference_header) + reader = NrrdReader(index_order="C") + image_array, image_header = reader.get_data(reader.read(filename)) + self.assertIsInstance(image_array, np.ndarray) + self.assertEqual(image_array.dtype, dtype) + self.assertTupleEqual(image_array.shape, expected_shape[::-1]) + self.assertTupleEqual(image_array.shape, tuple(image_header["spatial_shape"])) + + +if __name__ == "__main__": + unittest.main() From 76329cb5de5eab9fd122aaaa109b8eaec4c64251 Mon Sep 17 00:00:00 2001 From: kbressem Date: Thu, 12 May 2022 19:59:59 +0200 Subject: [PATCH 16/28] Add NrrdReader to list of possible readers for LoadImage Signed-off-by: kbressem --- monai/transforms/io/array.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 5bafd84eaf..3862ef9a80 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -28,7 +28,7 @@ from monai.config import DtypeLike, NdarrayOrTensor, PathLike from monai.data import image_writer from monai.data.folder_layout import FolderLayout -from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader +from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, NrrdReader from monai.transforms.transform import Transform from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import GridSampleMode, GridSamplePadMode @@ -37,11 +37,13 @@ nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") +nrrd, _ optional_import("nrrd") __all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"] SUPPORTED_READERS = { "itkreader": ITKReader, + "nrrdreader": NrrdReader, "numpyreader": NumpyReader, "pilreader": PILReader, "nibabelreader": NibabelReader, @@ -85,7 +87,7 @@ class LoadImage(Transform): - User-specified reader in the constructor of `LoadImage`. - Readers from the last to the first in the registered list. - Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), - (npz, npy -> NumpyReader), (DICOM file -> ITKReader). + (npz, npy -> NumpyReader), (nrrd -> NrrdReader), (DICOM file -> ITKReader). 
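+
+    For instance, a hypothetical ``image.nrrd`` can then be loaded by naming the
+    new reader explicitly (a sketch; the path is illustrative)::
+
+        data, meta = LoadImage(reader="NrrdReader")("image.nrrd")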
See also: From 46169c083ba693684dd2fde4f6ec1cce31ca55d4 Mon Sep 17 00:00:00 2001 From: kbressem Date: Thu, 12 May 2022 20:03:39 +0200 Subject: [PATCH 17/28] autofix formating Signed-off-by: kbressem --- monai/data/image_reader.py | 41 +++++++++++++++--------------- monai/transforms/io/array.py | 4 +-- tests/test_init_reader.py | 4 +-- tests/test_nrrd_reader.py | 49 ++++++++++++++++++++---------------- 4 files changed, 52 insertions(+), 46 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 219639bf8b..3ac0491b18 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -999,7 +999,7 @@ class NrrdReader(ImageReader): If None, `original_channel_dim` will be either `no_channel` or `0`. NRRD files are usually "channel first". dtype: dtype of the data array when loading image. - index_order: Specify whether the returned data array should be in C-order (‘C’) or Fortran-order (‘F’). + index_order: Specify whether the returned data array should be in C-order (‘C’) or Fortran-order (‘F’). Numpy is usually in C-order, but default on the NRRD header is F kwargs: additional args for `nrrd.read` API. more details about available args: https://github.com/mhe/pynrrd/blob/master/nrrd/reader.py @@ -1007,10 +1007,11 @@ class NrrdReader(ImageReader): """ def __init__( - self, channel_dim: Optional[int] = None, - dtype: Union[np.dtype, type, str, None] = np.float32, + self, + channel_dim: Optional[int] = None, + dtype: Union[np.dtype, type, str, None] = np.float32, index_order: str = "F", - **kwargs + **kwargs, ): self.channel_dim = channel_dim self.dtype = dtype @@ -1071,8 +1072,8 @@ def get_data(self, img: Union[NrrdImage, List[NrrdImage]]) -> Tuple[np.ndarray, header = self._switch_lps_ras(header) header["affine"] = header["original_affine"].copy() header["spatial_shape"] = header["sizes"] - [header.pop(k) for k in ("sizes", "space origin", "space directions")] # rm duplicated data in header - + [header.pop(k) for k in ("sizes", "space origin", "space directions")] # rm duplicated data in header + if self.channel_dim is None: # default to "no_channel" or -1 header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else 0 else: @@ -1104,20 +1105,20 @@ def _get_affine(self, img: NrrdImage) -> np.ndarray: origin = img.header["space origin"] x, y = direction.shape - affine_diam = min(x, y)+1 + affine_diam = min(x, y) + 1 affine: np.ndarray = np.eye(affine_diam) affine[:x, :y] = direction - affine[:(affine_diam-1), -1] = origin # len origin is always affine_diam - 1 + affine[: (affine_diam - 1), -1] = origin # len origin is always affine_diam - 1 return affine - + def _switch_lps_ras(self, header: dict) -> dict: """ - For compatibility with nibabel, switch from LPS to RAS. Adapt affine matrix and - `space` argument in header accordingly. - - Args: + For compatibility with nibabel, switch from LPS to RAS. Adapt affine matrix and + `space` argument in header accordingly. 
+ + Args: header: The image meta data as dict - + """ if header["space"] == "left-posterior-superior": header["space"] = "right-anterior-superior" @@ -1130,13 +1131,13 @@ def _convert_F_to_C_order(self, header: dict) -> dict: 1D arrays of header['space origin'] and header['sizes'] become inverted, e.g, [1,2,3] -> [3,2,1] The 2D Array for header['space directions'] is transposed: [[1,0,0],[0,2,0],[0,0,3]] -> [[3,0,0],[0,2,0],[0,0,1]] For more details refer to: https://pynrrd.readthedocs.io/en/latest/user-guide.html#index-ordering - - Args: + + Args: header: The image meta data as dict - + """ - - header["space directions"] = np.rot90(np.flip(header["space directions"] , 0)) + + header["space directions"] = np.rot90(np.flip(header["space directions"], 0)) header["space origin"] = header["space origin"][::-1] header["sizes"] = header["sizes"][::-1] - return header \ No newline at end of file + return header diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 3862ef9a80..fc34985903 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -28,7 +28,7 @@ from monai.config import DtypeLike, NdarrayOrTensor, PathLike from monai.data import image_writer from monai.data.folder_layout import FolderLayout -from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, NrrdReader +from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader from monai.transforms.transform import Transform from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import GridSampleMode, GridSamplePadMode @@ -37,7 +37,7 @@ nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") -nrrd, _ optional_import("nrrd") +nrrd, _ = optional_import("nrrd") __all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"] diff --git a/tests/test_init_reader.py b/tests/test_init_reader.py index 0c020c727e..df055e571c 100644 --- a/tests/test_init_reader.py +++ b/tests/test_init_reader.py @@ -11,7 +11,7 @@ import unittest -from monai.data import ITKReader, NibabelReader, NumpyReader, PILReader, NrrdReader +from monai.data import ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader from monai.transforms import LoadImage, LoadImaged from tests.utils import SkipIfNoModule @@ -47,7 +47,7 @@ def test_readers(self): inst = PILReader() self.assertIsInstance(inst, PILReader) - + inst = NrrdReader() self.assertIsInstance(inst, NrrdReader) diff --git a/tests/test_nrrd_reader.py b/tests/test_nrrd_reader.py index 50e38650f4..5561d471ba 100644 --- a/tests/test_nrrd_reader.py +++ b/tests/test_nrrd_reader.py @@ -13,9 +13,9 @@ import tempfile import unittest +import nrrd import numpy as np from parameterized import parameterized -import nrrd from monai.data import NrrdReader @@ -25,24 +25,29 @@ TEST_CASE_4 = [(1, 2, 3, 4, 5), "test_image.nrrd", (1, 2, 3, 4, 5), np.uint64] TEST_CASE_5 = [(6, 5, 4, 3, 2, 1), "test_image.nrrd", (6, 5, 4, 3, 2, 1), np.float32] TEST_CASE_6 = [(4,), "test_image.nrrd", (4,), np.float64] -TEST_CASE_7 = [(4,4), ["test_image.nrrd", "test_image2.nrrd", "test_image3.nrrd"], (4,4), np.float32] -TEST_CASE_8 = [(3,4,4,1), "test_image.nrrd", (3,4,4,1), np.float32, - { - "dimension": 4, - "space": "left-posterior-superior", - "sizes": [3,4,4,1], - "space directions": [[0.,0.,0.], [1.,0.,0.],[0.,1.,0.],[0.,0.,1.]], - "space origin": [0.,0.,0.] 
- }] +TEST_CASE_7 = [(4, 4), ["test_image.nrrd", "test_image2.nrrd", "test_image3.nrrd"], (4, 4), np.float32] +TEST_CASE_8 = [ + (3, 4, 4, 1), + "test_image.nrrd", + (3, 4, 4, 1), + np.float32, + { + "dimension": 4, + "space": "left-posterior-superior", + "sizes": [3, 4, 4, 1], + "space directions": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + "space origin": [0.0, 0.0, 0.0], + }, +] + class TestNrrdReader(unittest.TestCase): - - def test_verify_suffix(self): + def test_verify_suffix(self): reader = NrrdReader() self.assertFalse(reader.verify_suffix("test_image.nrd")) reader.verify_suffix("test_image.nrrd") reader.verify_suffix("test_image.seg.nrrd") - + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_read_int(self, data_shape, filename, expected_shape, dtype): min_val, max_val = np.iinfo(dtype).min, np.iinfo(dtype).max @@ -56,7 +61,7 @@ def test_read_int(self, data_shape, filename, expected_shape, dtype): self.assertTupleEqual(result.array.shape, expected_shape) self.assertTupleEqual(tuple(result.header["sizes"]), expected_shape) np.testing.assert_allclose(result.array, test_image) - + @parameterized.expand([TEST_CASE_5, TEST_CASE_6]) def test_read_float(self, data_shape, filename, expected_shape, dtype): test_image = np.random.rand(*data_shape).astype(dtype) @@ -69,7 +74,7 @@ def test_read_float(self, data_shape, filename, expected_shape, dtype): self.assertTupleEqual(result.array.shape, expected_shape) self.assertTupleEqual(tuple(result.header["sizes"]), expected_shape) np.testing.assert_allclose(result.array, test_image) - + @parameterized.expand([TEST_CASE_7]) def test_read_list(self, data_shape, filenames, expected_shape, dtype): test_image = np.random.rand(*data_shape).astype(dtype) @@ -83,9 +88,9 @@ def test_read_list(self, data_shape, filenames, expected_shape, dtype): self.assertTupleEqual(result.array.shape, expected_shape) self.assertTupleEqual(tuple(result.header["sizes"]), expected_shape) np.testing.assert_allclose(result.array, test_image) - + @parameterized.expand([TEST_CASE_8]) - def test_read_with_header(self, data_shape, filename, expected_shape, dtype, reference_header): + def test_read_with_header(self, data_shape, filename, expected_shape, dtype, reference_header): test_image = np.random.rand(*data_shape).astype(dtype) with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, filename) @@ -98,9 +103,9 @@ def test_read_with_header(self, data_shape, filename, expected_shape, dtype, re np.testing.assert_allclose(image_array, test_image) self.assertIsInstance(image_header, dict) self.assertTupleEqual(tuple(image_header["spatial_shape"]), expected_shape) - + @parameterized.expand([TEST_CASE_8]) - def test_read_with_header_index_order_c(self, data_shape, filename, expected_shape, dtype, reference_header): + def test_read_with_header_index_order_c(self, data_shape, filename, expected_shape, dtype, reference_header): test_image = np.random.rand(*data_shape).astype(dtype) with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, filename) @@ -110,8 +115,8 @@ def test_read_with_header_index_order_c(self, data_shape, filename, expected_sh self.assertIsInstance(image_array, np.ndarray) self.assertEqual(image_array.dtype, dtype) self.assertTupleEqual(image_array.shape, expected_shape[::-1]) - self.assertTupleEqual(image_array.shape, tuple(image_header["spatial_shape"])) - - + self.assertTupleEqual(image_array.shape, tuple(image_header["spatial_shape"])) + + if __name__ == 
"__main__": unittest.main() From 01e495a14cca373b8b1d72557fb6083d75fd51b1 Mon Sep 17 00:00:00 2001 From: kbressem Date: Thu, 12 May 2022 20:04:21 +0200 Subject: [PATCH 18/28] autofix formating Signed-off-by: kbressem --- monai/data/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index e58c26944d..63aa29df65 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -34,7 +34,7 @@ from .folder_layout import FolderLayout from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd from .image_dataset import ImageDataset -from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, NrrdReader +from .image_reader import ImageReader, ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader from .image_writer import ( SUPPORTED_WRITERS, ImageWriter, From c051993466cc90d72e4d6e3663a9efa341eb21cb Mon Sep 17 00:00:00 2001 From: kbressem Date: Thu, 12 May 2022 20:28:23 +0200 Subject: [PATCH 19/28] change NrrdImage class to namedtuple and make flake8 happy Signed-off-by: kbressem --- monai/data/image_reader.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 3ac0491b18..9e5fc60268 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -11,6 +11,7 @@ import warnings from abc import ABC, abstractmethod +from collections import namedtuple from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union @@ -980,14 +981,6 @@ def _extract_patches( return flat_patch_grid -class NrrdImage: - "Wrapper for image array and header" - - def __init__(self, array: np.ndarray, header: dict) -> None: - self.array = array - self.header = header - - @require_pkg(pkg_name="nrrd") class NrrdReader(ImageReader): """ @@ -1027,7 +1020,7 @@ def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: if a list of files, verify all the suffixes. """ - suffixes: Sequencec[str] = ["nrrd", "seg.nrrd"] + suffixes: Sequence[str] = ["nrrd", "seg.nrrd"] return has_nrrd and is_supported_format(filename, suffixes) def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Sequence[Any], Any]: @@ -1045,21 +1038,24 @@ def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Seq kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: - nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, *kwargs_)) + nrrd_image = namedtuple("nrrd_image", ["array", "header"]) + array, header = nrrd.read(name, index_order=self.index_order, *kwargs_) + nrrd_image.array = array + nrrd_image.header = header img_.append(nrrd_image) return img_ if len(filenames) > 1 else img_[0] - def get_data(self, img: Union[NrrdImage, List[NrrdImage]]) -> Tuple[np.ndarray, Dict]: + def get_data(self, img: Union[namedtuple, List[namedtuple]]) -> Tuple[np.ndarray, Dict]: """ Extract data array and meta data from loaded image and return them. This function must return two objects, the first is a numpy array of image data, the second is a dictionary of meta data. Args: - img: an `NrrdImage` object loaded from an image file or a list of image objects. + img: a nrrd image loaded from an image file or a list of image objects. 
""" - img_array: List[NrrdImage] = [] + img_array: List[namedtuple] = [] compatible_meta: Dict = {} for i in ensure_tuple(img): @@ -1082,23 +1078,23 @@ def get_data(self, img: Union[NrrdImage, List[NrrdImage]]) -> Tuple[np.ndarray, return _stack_images(img_array, compatible_meta), compatible_meta - def _get_array_data(self, img: NrrdImage) -> np.ndarray: + def _get_array_data(self, img: namedtuple) -> np.ndarray: """ Get the array data as Numpy array of `self.dtype` Args: - img: A `NrrdImage` loaded from image file + img: A nrrd image loaded from image file """ return img.array.astype(self.dtype) - def _get_affine(self, img: NrrdImage) -> np.ndarray: + def _get_affine(self, img: namedtuple) -> np.ndarray: """ Get the affine matrix of the image, it can be used to correct spacing, orientation or execute spatial transforms. Args: - img: A `NrrdImage` loaded from image file + img: A nrrd image loaded from image file """ direction = img.header["space directions"] @@ -1125,7 +1121,7 @@ def _switch_lps_ras(self, header: dict) -> dict: header["original_affine"] = orientation_ras_lps(header["original_affine"]) return header - def _convert_F_to_C_order(self, header: dict) -> dict: + def _convert_f_to_c_order(self, header: dict) -> dict: """ All header fields of a NRRD are specified in `F` (Fortran) order, even if the image was read as C-ordered array. 1D arrays of header['space origin'] and header['sizes'] become inverted, e.g, [1,2,3] -> [3,2,1] From b7c3efb5093ece10d5234a45df4c1d75ddd7aed4 Mon Sep 17 00:00:00 2001 From: kbressem Date: Fri, 13 May 2022 10:02:32 +0200 Subject: [PATCH 20/28] Add pynrrd to requirements Signed-off-by: kbressem --- requirements-dev.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index 651a99eba9..271e8db9e3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -47,3 +47,4 @@ types-PyYAML pyyaml fire jsonschema +pynrrd From 03b59eac99326d2a8ae61a2fea3b094a5400e5fb Mon Sep 17 00:00:00 2001 From: kbressem Date: Fri, 13 May 2022 10:19:07 +0200 Subject: [PATCH 21/28] correct typing for namedtumple make flake8 happy Signed-off-by: kbressem --- monai/data/image_reader.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 9e5fc60268..44cae04cc9 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -13,7 +13,7 @@ from abc import ABC, abstractmethod from collections import namedtuple from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern @@ -1045,7 +1045,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Seq img_.append(nrrd_image) return img_ if len(filenames) > 1 else img_[0] - def get_data(self, img: Union[namedtuple, List[namedtuple]]) -> Tuple[np.ndarray, Dict]: + def get_data(self, img: Union[NamedTuple, List[NamedTuple]]) -> Tuple[np.ndarray, Dict]: """ Extract data array and meta data from loaded image and return them. This function must return two objects, the first is a numpy array of image data, @@ -1055,7 +1055,7 @@ def get_data(self, img: Union[namedtuple, List[namedtuple]]) -> Tuple[np.ndarray img: a nrrd image loaded from an image file or a list of image objects. 
""" - img_array: List[namedtuple] = [] + img_array: List[NamedTuple] = [] compatible_meta: Dict = {} for i in ensure_tuple(img): @@ -1063,7 +1063,7 @@ def get_data(self, img: Union[namedtuple, List[namedtuple]]) -> Tuple[np.ndarray img_array.append(data) header = dict(i.header) if self.index_order == "C": - header = self._convert_F_to_C_order(header) + header = self._convert_f_to_c_order(header) header["original_affine"] = self._get_affine(i) header = self._switch_lps_ras(header) header["affine"] = header["original_affine"].copy() @@ -1078,7 +1078,7 @@ def get_data(self, img: Union[namedtuple, List[namedtuple]]) -> Tuple[np.ndarray return _stack_images(img_array, compatible_meta), compatible_meta - def _get_array_data(self, img: namedtuple) -> np.ndarray: + def _get_array_data(self, img: NamedTuple) -> np.ndarray: """ Get the array data as Numpy array of `self.dtype` @@ -1088,7 +1088,7 @@ def _get_array_data(self, img: namedtuple) -> np.ndarray: """ return img.array.astype(self.dtype) - def _get_affine(self, img: namedtuple) -> np.ndarray: + def _get_affine(self, img: NamedTuple) -> np.ndarray: """ Get the affine matrix of the image, it can be used to correct spacing, orientation or execute spatial transforms. From 7f8a4c18b2d950ee24d6c44c21c3baee45251433 Mon Sep 17 00:00:00 2001 From: kbressem Date: Fri, 13 May 2022 11:31:51 +0200 Subject: [PATCH 22/28] Add pynrrd info to `get_optional_config_values` Changed NrrdImage to dataclass Signed-off-by: kbressem --- monai/config/deviceconfig.py | 1 + monai/data/image_reader.py | 39 ++++++++++++++++-------------------- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index fd7ca572e6..8d6383ed97 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -75,6 +75,7 @@ def get_optional_config_values(): output["einops"] = get_package_version("einops") output["transformers"] = get_package_version("transformers") output["mlflow"] = get_package_version("mlflow") + output["pynrrd"] = get_package_version("nrrd") return output diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 44cae04cc9..8630d21afd 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -11,9 +11,9 @@ import warnings from abc import ABC, abstractmethod -from collections import namedtuple +from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern @@ -981,6 +981,14 @@ def _extract_patches( return flat_patch_grid +@dataclass +class NrrdImage: + """Class to wrap nrrd image array and metadata header""" + + array: np.ndarray + header: dict + + @require_pkg(pkg_name="nrrd") class NrrdReader(ImageReader): """ @@ -1038,28 +1046,25 @@ def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Seq kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: - nrrd_image = namedtuple("nrrd_image", ["array", "header"]) - array, header = nrrd.read(name, index_order=self.index_order, *kwargs_) - nrrd_image.array = array - nrrd_image.header = header + nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, *kwargs_)) img_.append(nrrd_image) return img_ if len(filenames) > 1 else img_[0] - def get_data(self, img: Union[NamedTuple, 
List[NamedTuple]]) -> Tuple[np.ndarray, Dict]: + def get_data(self, img: Union[NrrdImage, List[NrrdImage]]) -> Tuple[np.ndarray, Dict]: """ Extract data array and meta data from loaded image and return them. This function must return two objects, the first is a numpy array of image data, the second is a dictionary of meta data. Args: - img: a nrrd image loaded from an image file or a list of image objects. + img: a `NrrdImage` loaded from an image file or a list of image objects. """ - img_array: List[NamedTuple] = [] + img_array: List[np.ndarray] = [] compatible_meta: Dict = {} for i in ensure_tuple(img): - data = self._get_array_data(i) + data = i.array.astype(self.dtype) img_array.append(data) header = dict(i.header) if self.index_order == "C": @@ -1078,23 +1083,13 @@ def get_data(self, img: Union[NamedTuple, List[NamedTuple]]) -> Tuple[np.ndarray return _stack_images(img_array, compatible_meta), compatible_meta - def _get_array_data(self, img: NamedTuple) -> np.ndarray: - """ - Get the array data as Numpy array of `self.dtype` - - Args: - img: A nrrd image loaded from image file - - """ - return img.array.astype(self.dtype) - - def _get_affine(self, img: NamedTuple) -> np.ndarray: + def _get_affine(self, img: NrrdImage) -> np.ndarray: """ Get the affine matrix of the image, it can be used to correct spacing, orientation or execute spatial transforms. Args: - img: A nrrd image loaded from image file + img: A `NrrdImage` loaded from image file """ direction = img.header["space directions"] From b885c23fb26fa2035c148f6bd7e3cc239cc18d0e Mon Sep 17 00:00:00 2001 From: kbressem Date: Fri, 13 May 2022 12:21:39 +0200 Subject: [PATCH 23/28] exclude test_nrrd_reader.py from min tests Signed-off-by: kbressem --- tests/min_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index 6549fdcd4b..f17aaa85b0 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -114,6 +114,7 @@ def run_testsuit(): "test_nifti_header_revise", "test_nifti_rw", "test_nifti_saver", + "test_nrrd_reader", "test_occlusion_sensitivity", "test_orientation", "test_orientationd", From 7294d3445d73889b94f118bd5cbc762936d5d584 Mon Sep 17 00:00:00 2001 From: kbressem Date: Fri, 13 May 2022 13:15:41 +0200 Subject: [PATCH 24/28] add pynrrd to config files Signed-off-by: kbressem --- docs/requirements.txt | 1 + docs/source/installation.md | 4 ++-- environment-dev.yml | 1 + setup.cfg | 5 ++++- 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index f9749e9e36..b7edff27fa 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -28,3 +28,4 @@ tifffile; platform_system == "Linux" pyyaml fire jsonschema +pynrrd diff --git a/docs/source/installation.md b/docs/source/installation.md index 12bf544cba..76c9166566 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -190,9 +190,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, pynrrd] ``` which correspond to `nibabel`, `scikit-image`, `pillow`, 
`tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `pynrrd`, respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/environment-dev.yml b/environment-dev.yml index a361262930..9eef775b78 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -45,6 +45,7 @@ dependencies: - pyyaml - fire - jsonschema + - pynrrd - pip - pip: # pip for itk as conda-forge version only up to v5.1 diff --git a/setup.cfg b/setup.cfg index 12f974ca6d..d12fa61cef 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,6 +53,7 @@ all = pyyaml fire jsonschema + pynrrd nibabel = nibabel skimage = @@ -101,7 +102,9 @@ fire = fire jsonschema = jsonschema - +pynrrd = + pynrrd + [flake8] select = B,C,E,F,N,P,T4,W,B9 max_line_length = 120 From a5aa7e51acd0645b6a089c737e6b01ce9754888d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 May 2022 11:20:26 +0000 Subject: [PATCH 25/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index d12fa61cef..914e404b2d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -104,7 +104,7 @@ jsonschema = jsonschema pynrrd = pynrrd - + [flake8] select = B,C,E,F,N,P,T4,W,B9 max_line_length = 120 From 4330fd83f55950355df8f3374b5605791d42ec95 Mon Sep 17 00:00:00 2001 From: kbressem Date: Fri, 13 May 2022 14:19:35 +0200 Subject: [PATCH 26/28] Change the way space is handled in the header. Now, if space is not in header, it is assumed to be LPS and converted to RAS. If space is defined and not LPS, nothing is done to prevent wrong conversions. Signed-off-by: kbressem --- monai/data/image_reader.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 8630d21afd..da4c8e6a43 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1105,13 +1105,15 @@ def _get_affine(self, img: NrrdImage) -> np.ndarray: def _switch_lps_ras(self, header: dict) -> dict: """ For compatibility with nibabel, switch from LPS to RAS. Adapt affine matrix and - `space` argument in header accordingly. + `space` argument in header accordingly. If no information of space is given in the header, + LPS is assumed and thus converted to RAS. If information about space is given, + but is not LPS, the unchanged header is returned. Args: header: The image meta data as dict """ - if header["space"] == "left-posterior-superior": + if "space" not in header or header["space"] == "left-posterior-superior": header["space"] = "right-anterior-superior" header["original_affine"] = orientation_ras_lps(header["original_affine"]) return header From ce9c88347b8f33a27a8b0acf791760d0c95ce269 Mon Sep 17 00:00:00 2001 From: kbressem Date: Fri, 13 May 2022 14:21:28 +0200 Subject: [PATCH 27/28] add `TestLoadSaveNrrd` where it is tested if a nrrd file, created by ITKWriter can be loaded again. 
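
In rough outline, the round trip being verified looks like this (a standalone
sketch, not part of the diff; file names and shapes are illustrative):

    import numpy as np
    from monai.data.image_reader import NrrdReader
    from monai.data.image_writer import ITKWriter
    from monai.transforms import LoadImage, SaveImage

    data = np.random.randn(8, 8, 8).astype(np.float32)
    saver = SaveImage(output_dir=".", output_ext=".nrrd", separate_folder=False, writer=ITKWriter)
    saver(data, {"filename_or_obj": "testfile.nrrd", "spatial_shape": data.shape})
    # SaveImage adds the default "_trans" postfix to the file stem
    loaded, meta = LoadImage(reader=NrrdReader)("./testfile_trans.nrrd")
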
2D and 3D files with no channels are tested Signed-off-by: kbressem --- tests/test_image_rw.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/test_image_rw.py b/tests/test_image_rw.py index 62b1147aa5..404f8b66c6 100644 --- a/tests/test_image_rw.py +++ b/tests/test_image_rw.py @@ -18,7 +18,7 @@ import numpy as np from parameterized import parameterized -from monai.data.image_reader import ITKReader, NibabelReader, PILReader +from monai.data.image_reader import ITKReader, NibabelReader, PILReader, NrrdReader from monai.data.image_writer import ITKWriter, NibabelWriter, PILWriter, register_writer, resolve_writer from monai.transforms import LoadImage, SaveImage, moveaxis from monai.utils import OptionalImportError @@ -132,5 +132,38 @@ def test_1_new(self): self.assertEqual(resolve_writer("new")[0](0), 1) +class TestLoadSaveNrrd(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.test_dir, ignore_errors=True) + + def nrrd_rw(self, test_data, reader, writer, dtype, resample=True): + test_data = test_data.astype(dtype) + ndim = len(test_data.shape) + for p in TEST_NDARRAYS: + output_ext = ".nrrd" + filepath = f"testfile_{ndim}d" + saver = SaveImage( + output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer + ) + saver(p(test_data), {"filename_or_obj": f"{filepath}{output_ext}", "spatial_shape": test_data.shape}) + saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext) + loader = LoadImage(reader=reader) + data, meta = loader(saved_path) + assert_allclose(data, test_data) + + @parameterized.expand(itertools.product([NrrdReader, ITKReader], [ITKWriter, ITKWriter])) + def test_2d(self, reader, writer): + test_data = np.random.randn(8,8).astype(np.float32) + self.nrrd_rw(test_data, reader, writer, np.float32) + + @parameterized.expand(itertools.product([NrrdReader, ITKReader], [ITKWriter, ITKWriter])) + def test_3d(self, reader, writer): + test_data = np.random.randn(8,8,8).astype(np.float32) + self.nrrd_rw(test_data, reader, writer, np.float32) + + if __name__ == "__main__": unittest.main() From 7417d898637a02c8642dcdb9129069f5e015c100 Mon Sep 17 00:00:00 2001 From: kbressem Date: Fri, 13 May 2022 14:23:22 +0200 Subject: [PATCH 28/28] autofix format Signed-off-by: kbressem --- monai/data/image_reader.py | 2 +- tests/test_image_rw.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index da4c8e6a43..af098c0fa3 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1107,7 +1107,7 @@ def _switch_lps_ras(self, header: dict) -> dict: For compatibility with nibabel, switch from LPS to RAS. Adapt affine matrix and `space` argument in header accordingly. If no information of space is given in the header, LPS is assumed and thus converted to RAS. If information about space is given, - but is not LPS, the unchanged header is returned. + but is not LPS, the unchanged header is returned. 
Args: header: The image meta data as dict diff --git a/tests/test_image_rw.py b/tests/test_image_rw.py index 404f8b66c6..7975349109 100644 --- a/tests/test_image_rw.py +++ b/tests/test_image_rw.py @@ -18,7 +18,7 @@ import numpy as np from parameterized import parameterized -from monai.data.image_reader import ITKReader, NibabelReader, PILReader, NrrdReader +from monai.data.image_reader import ITKReader, NibabelReader, NrrdReader, PILReader from monai.data.image_writer import ITKWriter, NibabelWriter, PILWriter, register_writer, resolve_writer from monai.transforms import LoadImage, SaveImage, moveaxis from monai.utils import OptionalImportError @@ -147,7 +147,7 @@ def nrrd_rw(self, test_data, reader, writer, dtype, resample=True): filepath = f"testfile_{ndim}d" saver = SaveImage( output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer - ) + ) saver(p(test_data), {"filename_or_obj": f"{filepath}{output_ext}", "spatial_shape": test_data.shape}) saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext) loader = LoadImage(reader=reader) @@ -156,12 +156,12 @@ def nrrd_rw(self, test_data, reader, writer, dtype, resample=True): @parameterized.expand(itertools.product([NrrdReader, ITKReader], [ITKWriter, ITKWriter])) def test_2d(self, reader, writer): - test_data = np.random.randn(8,8).astype(np.float32) + test_data = np.random.randn(8, 8).astype(np.float32) self.nrrd_rw(test_data, reader, writer, np.float32) @parameterized.expand(itertools.product([NrrdReader, ITKReader], [ITKWriter, ITKWriter])) def test_3d(self, reader, writer): - test_data = np.random.randn(8,8,8).astype(np.float32) + test_data = np.random.randn(8, 8, 8).astype(np.float32) self.nrrd_rw(test_data, reader, writer, np.float32)
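
Taken together, the affine handling settled in patches 14 and 26 can be
checked by hand. A worked sketch with illustrative values (`direction` and
`origin` stand in for the parsed "space directions" and "space origin"
header fields):

    import numpy as np

    direction = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 2.0]])
    origin = np.array([10.0, -5.0, 2.5])

    # _get_affine: embed directions and origin into a homogeneous matrix
    x, y = direction.shape
    affine = np.eye(min(x, y) + 1)
    affine[:x, :y] = direction
    affine[:-1, -1] = origin

    # _switch_lps_ras for space "left-posterior-superior": flip the first
    # two axes, as orientation_ras_lps does for a 4x4 affine
    affine = np.diag([-1.0, -1.0, 1.0, 1.0]) @ affine
    # diagonal is now [-1., -1., 2., 1.] and the translation is [-10., 5., 2.5]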