From fac754d2c30435d3ba974bed1927aff8892e77c5 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 5 Dec 2023 07:55:16 -0500 Subject: [PATCH 01/32] 6676 port generative networks autoencoderkl (#7260) Partially fixes #6676 ### Description Implements the AutoencoderKL network from MONAI Generative. NB this network is subject to a planned refactor once the porting is complete, [see here](https://github.com/Project-MONAI/MONAI/issues/7227). ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/networks.rst | 5 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/autoencoderkl.py | 807 +++++++++++++++++++++++++++ tests/test_autoencoderkl.py | 276 +++++++++ 4 files changed, 1089 insertions(+) create mode 100644 monai/networks/nets/autoencoderkl.py create mode 100644 tests/test_autoencoderkl.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 8eada7933f..dbfdf35784 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -595,6 +595,11 @@ Nets .. autoclass:: AutoEncoder :members: +`AutoEncoderKL` +~~~~~~~~~~~~~~~ +.. autoclass:: AutoencoderKL + :members: + `VarAutoEncoder` ~~~~~~~~~~~~~~~~ .. autoclass:: VarAutoEncoder diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 9247aaee85..ea08246d25 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -14,6 +14,7 @@ from .ahnet import AHnet, Ahnet, AHNet from .attentionunet import AttentionUnet from .autoencoder import AutoEncoder +from .autoencoderkl import AutoencoderKL from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus from .classifier import Classifier, Critic, Discriminator diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py new file mode 100644 index 0000000000..9a9f35d5ae --- /dev/null +++ b/monai/networks/nets/autoencoderkl.py @@ -0,0 +1,807 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +from collections.abc import Sequence +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution + +# To install xformers, use pip install xformers==0.0.16rc401 +from monai.utils import ensure_tuple_rep, optional_import + +xformers, has_xformers = optional_import("xformers") + +__all__ = ["AutoencoderKL"] + + +class _Upsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Convolution-based upsampling layer. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels to the layer. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) -> None: + super().__init__() + if use_convtranspose: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=1, + conv_only=True, + is_transposed=True, + ) + else: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.use_convtranspose = use_convtranspose + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_convtranspose: + conv: torch.Tensor = self.conv(x) + return conv + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + x = x.to(dtype) + + x = self.conv(x) + return x + + +class _Downsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Convolution-based downsampling layer. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + """ + + def __init__(self, spatial_dims: int, in_channels: int) -> None: + super().__init__() + self.pad = (0, 1) * spatial_dims + + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = nn.functional.pad(x, self.pad, mode="constant", value=0.0) + x = self.conv(x) + return x + + +class _ResBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a + residual connection between input and output. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: input channels to the layer. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon for the normalisation. + out_channels: number of output channels. + """ + + def __init__( + self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True) + self.conv2 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.nin_shortcut: nn.Module + if self.in_channels != self.out_channels: + self.nin_shortcut = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + else: + self.nin_shortcut = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = F.silu(h) + h = self.conv1(h) + + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + + x = self.nin_shortcut(x) + + return x + h + + +class _AttentionBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Attention block. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + num_channels: number of input channels. + num_head_channels: number of channels in each attention head. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon value to use for the normalisation. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + self.spatial_dims = spatial_dims + self.num_channels = num_channels + + self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.scale = 1 / math.sqrt(num_channels / self.num_heads) + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + + self.to_q = nn.Linear(num_channels, num_channels) + self.to_k = nn.Linear(num_channels, num_channels) + self.to_v = nn.Linear(num_channels, num_channels) + + self.proj_attn = nn.Linear(num_channels, num_channels) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + """ + Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. + """ + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + """Combine the output of the attention heads back into the hidden state dimension.""" + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x: torch.Tensor = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + # norm + x = self.norm(x) + + if self.spatial_dims == 2: + x = x.view(batch, channel, height * width).transpose(1, 2) + if self.spatial_dims == 3: + x = x.view(batch, channel, height * width * depth).transpose(1, 2) + + # proj to q, k, v + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + + if self.spatial_dims == 2: + x = x.transpose(-1, -2).reshape(batch, channel, height, width) + if self.spatial_dims == 3: + x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) + + return x + residual + + +class Encoder(nn.Module): + """ + Convolutional cascade that downsamples the image into a spatial latent space. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + channels: sequence of block output channels. + out_channels: number of channels in the bottom layer (latent space) of the autoencoder. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from num_channels contain an attention block. + with_nonlocal_attn: if True use non-local attention block. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + channels: Sequence[int], + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.channels = channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + + blocks: List[nn.Module] = [] + # Initial convolution + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Residual and downsampling blocks + output_channel = channels[0] + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) - 1 + + for _ in range(self.num_res_blocks[i]): + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=output_channel, + ) + ) + input_channel = output_channel + if attention_levels[i]: + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append(_Downsample(spatial_dims=spatial_dims, in_channels=input_channel)) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=channels[-1], + ) + ) + + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=channels[-1], + ) + ) + # Normalise and convert to latent size + blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[-1], eps=norm_eps, affine=True)) + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=channels[-1], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class Decoder(nn.Module): + """ + Convolutional cascade upsampling from a spatial latent space into an image space. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + channels: sequence of block output channels. + in_channels: number of channels in the bottom layer (latent space) of the autoencoder. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from num_channels contain an attention block. + with_nonlocal_attn: if True use non-local attention block. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__( + self, + spatial_dims: int, + channels: Sequence[int], + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_convtranspose: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = channels + self.in_channels = in_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + + reversed_block_out_channels = list(reversed(channels)) + + blocks: List[nn.Module] = [] + + # Initial convolution + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=reversed_block_out_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + block_out_ch = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + block_in_ch = block_out_ch + block_out_ch = reversed_block_out_channels[i] + is_final_block = i == len(channels) - 1 + + for _ in range(reversed_num_res_blocks[i]): + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=block_out_ch, + ) + ) + block_in_ch = block_out_ch + + if reversed_attention_levels[i]: + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append( + _Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=use_convtranspose) + ) + + blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class AutoencoderKL(nn.Module): + """ + Autoencoder model with KL-regularized latent space based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + channels: number of output channels for each block. + attention_levels: sequence of levels to add attention. + latent_channels: latent embedding dimension. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. + with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + use_checkpoint: if True, use activation checkpoint to save memory. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int = 1, + out_channels: int = 1, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + latent_channels: int = 3, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + with_encoder_nonlocal_attn: bool = True, + with_decoder_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_checkpoint: bool = False, + use_convtranspose: bool = False, + ) -> None: + super().__init__() + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("AutoencoderKL expects all num_channels being multiple of norm_num_groups") + + if len(channels) != len(attention_levels): + raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels") + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=channels, + out_channels=latent_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_encoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + ) + self.decoder = Decoder( + spatial_dims=spatial_dims, + channels=channels, + in_channels=latent_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_decoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + use_convtranspose=use_convtranspose, + ) + self.quant_conv_mu = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.quant_conv_log_sigma = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.post_quant_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.latent_channels = latent_channels + self.use_checkpoint = use_checkpoint + + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations. + + Args: + x: BxCx[SPATIAL DIMS] tensor + + """ + if self.use_checkpoint: + h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False) + else: + h = self.encoder(x) + + z_mu = self.quant_conv_mu(h) + z_log_var = self.quant_conv_log_sigma(h) + z_log_var = torch.clamp(z_log_var, -30.0, 20.0) + z_sigma = torch.exp(z_log_var / 2) + + return z_mu, z_sigma + + def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor: + """ + From the mean and sigma representations resulting of encoding an image through the latent space, + obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and + adding the mean. + + Args: + z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image + z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image + + Returns: + sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE] + """ + eps = torch.randn_like(z_sigma) + z_vae = z_mu + eps * z_sigma + return z_vae + + def reconstruct(self, x: torch.Tensor) -> torch.Tensor: + """ + Encodes and decodes an input image. + + Args: + x: BxCx[SPATIAL DIMENSIONS] tensor. + + Returns: + reconstructed image, of the same shape as input + """ + z_mu, _ = self.encode(x) + reconstruction = self.decode(z_mu) + return reconstruction + + def decode(self, z: torch.Tensor) -> torch.Tensor: + """ + Based on a latent space sample, forwards it through the Decoder. + + Args: + z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE] + + Returns: + decoded image tensor + """ + z = self.post_quant_conv(z) + dec: torch.Tensor + if self.use_checkpoint: + dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False) + else: + dec = self.decoder(z) + return dec + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + reconstruction = self.decode(z) + return reconstruction, z_mu, z_sigma + + def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + return z + + def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: + image = self.decode(z) + return image diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py new file mode 100644 index 0000000000..448f1e8e9a --- /dev/null +++ b/tests/test_autoencoderkl.py @@ -0,0 +1,276 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import AutoencoderKL +from tests.utils import SkipIfBeforePyTorchVersion + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": (1, 1, 2), + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], +] + + +class TestAutoEncoderKL(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape): + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + self.assertEqual(result[2].shape, expected_latent_shape) + + @parameterized.expand(CASES) + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_with_convtranspose_and_checkpointing( + self, input_param, input_shape, expected_shape, expected_latent_shape + ): + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + self.assertEqual(result[2].shape, expected_latent_shape) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_num_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_num_channels_not_same_size_of_num_res_blocks(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=(8, 8), + norm_num_groups=16, + ) + + def test_shape_reconstruction(self): + input_param, input_shape, expected_shape, _ = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.reconstruct(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_reconstruction_with_convtranspose_and_checkpointing(self): + input_param, input_shape, expected_shape, _ = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.reconstruct(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_shape_encode(self): + input_param, input_shape, _, expected_latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_encode_with_convtranspose_and_checkpointing(self): + input_param, input_shape, _, expected_latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_shape_sampling(self): + input_param, _, _, expected_latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_sampling_convtranspose_and_checkpointing(self): + input_param, _, _, expected_latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + def test_shape_decode(self): + input_param, expected_input_shape, _, latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_decode_convtranspose_and_checkpointing(self): + input_param, expected_input_shape, _, latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + +if __name__ == "__main__": + unittest.main() From b3fdfdd2111c5d1349a345fbd4e24c570d1fb690 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 6 Dec 2023 22:36:50 -0500 Subject: [PATCH 02/32] 6676 port generative networks vqvae (#7285) Partially fixes https://github.com/Project-MONAI/MONAI/issues/6676 ### Description Implements the VQ-VAE network, including the vector quantizer block, from MONAI Generative. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu Signed-off-by: Mark Graham Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: KumoLiu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/source/networks.rst | 13 + monai/bundle/scripts.py | 2 +- monai/networks/layers/__init__.py | 1 + monai/networks/layers/vector_quantizer.py | 233 +++++++++++ monai/networks/nets/__init__.py | 1 + monai/networks/nets/autoencoderkl.py | 14 +- monai/networks/nets/vqvae.py | 466 ++++++++++++++++++++++ tests/test_vector_quantizer.py | 89 +++++ tests/test_vqvae.py | 274 +++++++++++++ 9 files changed, 1085 insertions(+), 8 deletions(-) create mode 100644 monai/networks/layers/vector_quantizer.py create mode 100644 monai/networks/nets/vqvae.py create mode 100644 tests/test_vector_quantizer.py create mode 100644 tests/test_vqvae.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index dbfdf35784..d8be26264b 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -258,6 +258,7 @@ N-Dim Fourier Transform .. autofunction:: monai.networks.blocks.fft_utils_t.fftshift .. autofunction:: monai.networks.blocks.fft_utils_t.ifftshift + Layers ------ @@ -408,6 +409,13 @@ Layers .. autoclass:: LLTM :members: +`Vector Quantizer` +~~~~~~~~~~~~~~~~~~ +.. autoclass:: monai.networks.layers.vector_quantizer.EMAQuantizer + :members: +.. autoclass:: monai.networks.layers.vector_quantizer.VectorQuantizer + :members: + `Utilities` ~~~~~~~~~~~ .. automodule:: monai.networks.layers.convutils @@ -728,6 +736,11 @@ Nets .. autoclass:: voxelmorph +`VQ-VAE` +~~~~~~~~ +.. autoclass:: VQVAE + :members: + Utilities --------- .. automodule:: monai.networks.utils diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 20a491e493..2565a3cf64 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -221,7 +221,7 @@ def _download_from_ngc( def _get_latest_bundle_version_monaihosting(name): url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" - full_url = f"{url}/{name}" + full_url = f"{url}/{name.lower()}" requests_get, has_requests = optional_import("requests", name="get") if has_requests: resp = requests_get(full_url) diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index d61ed57f7f..bd3e3af3af 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -37,4 +37,5 @@ ) from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer +from .vector_quantizer import EMAQuantizer, VectorQuantizer from .weight_init import _no_grad_trunc_normal_, trunc_normal_ diff --git a/monai/networks/layers/vector_quantizer.py b/monai/networks/layers/vector_quantizer.py new file mode 100644 index 0000000000..9c354e1009 --- /dev/null +++ b/monai/networks/layers/vector_quantizer.py @@ -0,0 +1,233 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Sequence, Tuple + +import torch +from torch import nn + +__all__ = ["VectorQuantizer", "EMAQuantizer"] + + +class EMAQuantizer(nn.Module): + """ + Vector Quantization module using Exponential Moving Average (EMA) to learn the codebook parameters based on Neural + Discrete Representation Learning by Oord et al. (https://arxiv.org/abs/1711.00937) and the official implementation + that can be found at https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py#L148 and commit + 58d9a2746493717a7c9252938da7efa6006f3739. + + This module is not compatible with TorchScript while working in a Distributed Data Parallelism Module. This is due + to lack of TorchScript support for torch.distributed module as per https://github.com/pytorch/pytorch/issues/41353 + on 22/10/2022. If you want to TorchScript your model, please turn set `ddp_sync` to False. + + Args: + spatial_dims: number of spatial dimensions of the input. + num_embeddings: number of atomic elements in the codebook. + embedding_dim: number of channels of the input and atomic elements. + commitment_cost: scaling factor of the MSE loss between input and its quantized version. Defaults to 0.25. + decay: EMA decay. Defaults to 0.99. + epsilon: epsilon value. Defaults to 1e-5. + embedding_init: initialization method for the codebook. Defaults to "normal". + ddp_sync: whether to synchronize the codebook across processes. Defaults to True. + """ + + def __init__( + self, + spatial_dims: int, + num_embeddings: int, + embedding_dim: int, + commitment_cost: float = 0.25, + decay: float = 0.99, + epsilon: float = 1e-5, + embedding_init: str = "normal", + ddp_sync: bool = True, + ): + super().__init__() + self.spatial_dims: int = spatial_dims + self.embedding_dim: int = embedding_dim + self.num_embeddings: int = num_embeddings + + assert self.spatial_dims in [2, 3], ValueError( + f"EMAQuantizer only supports 4D and 5D tensor inputs but received spatial dims {spatial_dims}." + ) + + self.embedding: torch.nn.Embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim) + if embedding_init == "normal": + # Initialization is passed since the default one is normal inside the nn.Embedding + pass + elif embedding_init == "kaiming_uniform": + torch.nn.init.kaiming_uniform_(self.embedding.weight.data, mode="fan_in", nonlinearity="linear") + self.embedding.weight.requires_grad = False + + self.commitment_cost: float = commitment_cost + + self.register_buffer("ema_cluster_size", torch.zeros(self.num_embeddings)) + self.register_buffer("ema_w", self.embedding.weight.data.clone()) + # declare types for mypy + self.ema_cluster_size: torch.Tensor + self.ema_w: torch.Tensor + self.decay: float = decay + self.epsilon: float = epsilon + + self.ddp_sync: bool = ddp_sync + + # Precalculating required permutation shapes + self.flatten_permutation = [0] + list(range(2, self.spatial_dims + 2)) + [1] + self.quantization_permutation: Sequence[int] = [0, self.spatial_dims + 1] + list( + range(1, self.spatial_dims + 1) + ) + + def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss. + + Args: + inputs: Encoding space tensors of shape [B, C, H, W, D]. + + Returns: + torch.Tensor: Flatten version of the input of shape [B*H*W*D, C]. + torch.Tensor: One-hot representation of the quantization indices of shape [B*H*W*D, self.num_embeddings]. + torch.Tensor: Quantization indices of shape [B,H,W,D,1] + + """ + with torch.cuda.amp.autocast(enabled=False): + encoding_indices_view = list(inputs.shape) + del encoding_indices_view[1] + + inputs = inputs.float() + + # Converting to channel last format + flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim) + + # Calculate Euclidean distances + distances = ( + (flat_input**2).sum(dim=1, keepdim=True) + + (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True) + - 2 * torch.mm(flat_input, self.embedding.weight.t()) + ) + + # Mapping distances to indexes + encoding_indices = torch.max(-distances, dim=1)[1] + encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float() + + # Quantize and reshape + encoding_indices = encoding_indices.view(encoding_indices_view) + + return flat_input, encodings, encoding_indices + + def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: + """ + Given encoding indices of shape [B,D,H,W,1] embeds them in the quantized space + [B, D, H, W, self.embedding_dim] and reshapes them to [B, self.embedding_dim, D, H, W] to be fed to the + decoder. + + Args: + embedding_indices: Tensor in channel last format which holds indices referencing atomic + elements from self.embedding + + Returns: + torch.Tensor: Quantize space representation of encoding_indices in channel first format. + """ + with torch.cuda.amp.autocast(enabled=False): + embedding: torch.Tensor = ( + self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous() + ) + return embedding + + def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None: + """ + TorchScript does not support torch.distributed.all_reduce. This function is a bypassing trick based on the + example: https://pytorch.org/docs/stable/generated/torch.jit.unused.html#torch.jit.unused + + Args: + encodings_sum: The summation of one hot representation of what encoding was used for each + position. + dw: The multiplication of the one hot representation of what encoding was used for each + position with the flattened input. + + Returns: + None + """ + if self.ddp_sync and torch.distributed.is_initialized(): + torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM) + else: + pass + + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat_input, encodings, encoding_indices = self.quantize(inputs) + quantized = self.embed(encoding_indices) + + # Use EMA to update the embedding vectors + if self.training: + with torch.no_grad(): + encodings_sum = encodings.sum(0) + dw = torch.mm(encodings.t(), flat_input) + + if self.ddp_sync: + self.distributed_synchronization(encodings_sum, dw) + + self.ema_cluster_size.data.mul_(self.decay).add_(torch.mul(encodings_sum, 1 - self.decay)) + + # Laplace smoothing of the cluster size + n = self.ema_cluster_size.sum() + weights = (self.ema_cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n + self.ema_w.data.mul_(self.decay).add_(torch.mul(dw, 1 - self.decay)) + self.embedding.weight.data.copy_(self.ema_w / weights.unsqueeze(1)) + + # Encoding Loss + loss = self.commitment_cost * torch.nn.functional.mse_loss(quantized.detach(), inputs) + + # Straight Through Estimator + quantized = inputs + (quantized - inputs).detach() + + return quantized, loss, encoding_indices + + +class VectorQuantizer(torch.nn.Module): + """ + Vector Quantization wrapper that is needed as a workaround for the AMP to isolate the non fp16 compatible parts of + the quantization in their own class. + + Args: + quantizer (torch.nn.Module): Quantizer module that needs to return its quantized representation, loss and index + based quantized representation. + """ + + def __init__(self, quantizer: EMAQuantizer): + super().__init__() + + self.quantizer: EMAQuantizer = quantizer + + self.perplexity: torch.Tensor = torch.rand(1) + + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + quantized, loss, encoding_indices = self.quantizer(inputs) + # Perplexity calculations + avg_probs = ( + torch.histc(encoding_indices.float(), bins=self.quantizer.num_embeddings, max=self.quantizer.num_embeddings) + .float() + .div(encoding_indices.numel()) + ) + + self.perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return loss, quantized + + def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: + return self.quantizer.embed(embedding_indices=embedding_indices) + + def quantize(self, encodings: torch.Tensor) -> torch.Tensor: + output = self.quantizer(encodings) + encoding_indices: torch.Tensor = output[2] + return encoding_indices diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index ea08246d25..db3c77c717 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -113,3 +113,4 @@ from .vitautoenc import ViTAutoEnc from .vnet import VNet from .voxelmorph import VoxelMorph, VoxelMorphUNet +from .vqvae import VQVAE diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 9a9f35d5ae..f7ae77f056 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -38,7 +38,7 @@ class _Upsample(nn.Module): Convolution-based upsampling layer. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: number of input channels to the layer. use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. """ @@ -98,7 +98,7 @@ class _Downsample(nn.Module): Convolution-based downsampling layer. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: number of input channels. """ @@ -132,7 +132,7 @@ class _ResBlock(nn.Module): residual connection between input and output. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: input channels to the layer. norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of channels is divisible by this number. @@ -206,7 +206,7 @@ class _AttentionBlock(nn.Module): Attention block. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. num_channels: number of input channels. num_head_channels: number of channels in each attention head. norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of @@ -325,7 +325,7 @@ class Encoder(nn.Module): Convolutional cascade that downsamples the image into a spatial latent space. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: number of input channels. channels: sequence of block output channels. out_channels: number of channels in the bottom layer (latent space) of the autoencoder. @@ -463,7 +463,7 @@ class Decoder(nn.Module): Convolutional cascade upsampling from a spatial latent space into an image space. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. channels: sequence of block output channels. in_channels: number of channels in the bottom layer (latent space) of the autoencoder. out_channels: number of output channels. @@ -611,7 +611,7 @@ class AutoencoderKL(nn.Module): and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: number of input channels. out_channels: number of output channels. num_res_blocks: number of residual blocks (see _ResBlock) per level. diff --git a/monai/networks/nets/vqvae.py b/monai/networks/nets/vqvae.py new file mode 100644 index 0000000000..d4771e203a --- /dev/null +++ b/monai/networks/nets/vqvae.py @@ -0,0 +1,466 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Tuple + +import torch +import torch.nn as nn + +from monai.networks.blocks import Convolution +from monai.networks.layers import Act +from monai.networks.layers.vector_quantizer import EMAQuantizer, VectorQuantizer +from monai.utils import ensure_tuple_rep + +__all__ = ["VQVAE"] + + +class VQVAEResidualUnit(nn.Module): + """ + Implementation of the ResidualLayer used in the VQVAE network as originally used in Morphology-preserving + Autoregressive 3D Generative Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf). + + The original implementation that can be found at + https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L150. + + Args: + spatial_dims: number of spatial spatial_dims of the input data. + in_channels: number of input channels. + num_res_channels: number of channels in the residual layers. + act: activation type and arguments. Defaults to RELU. + dropout: dropout ratio. Defaults to no dropout. + bias: whether to have a bias term. Defaults to True. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_res_channels: int, + act: tuple | str | None = Act.RELU, + dropout: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.num_res_channels = num_res_channels + self.act = act + self.dropout = dropout + self.bias = bias + + self.conv1 = Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels, + out_channels=self.num_res_channels, + adn_ordering="DA", + act=self.act, + dropout=self.dropout, + bias=self.bias, + ) + + self.conv2 = Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.num_res_channels, + out_channels=self.in_channels, + bias=self.bias, + conv_only=True, + ) + + def forward(self, x): + return torch.nn.functional.relu(x + self.conv2(self.conv1(x)), True) + + +class Encoder(nn.Module): + """ + Encoder module for VQ-VAE. + + Args: + spatial_dims: number of spatial spatial_dims. + in_channels: number of input channels. + out_channels: number of channels in the latent space (embedding_dim). + channels: sequence containing the number of channels at each level of the encoder. + num_res_layers: number of sequential residual layers at each level. + num_res_channels: number of channels in the residual layers at each level. + downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int) and padding (int). + dropout: dropout ratio. + act: activation type and arguments. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int], + num_res_layers: int, + num_res_channels: Sequence[int], + downsample_parameters: Sequence[Tuple[int, int, int, int]], + dropout: float, + act: tuple | str | None, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + self.downsample_parameters = downsample_parameters + self.dropout = dropout + self.act = act + + blocks: list[nn.Module] = [] + + for i in range(len(self.channels)): + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels if i == 0 else self.channels[i - 1], + out_channels=self.channels[i], + strides=self.downsample_parameters[i][0], + kernel_size=self.downsample_parameters[i][1], + adn_ordering="DA", + act=self.act, + dropout=None if i == 0 else self.dropout, + dropout_dim=1, + dilation=self.downsample_parameters[i][2], + padding=self.downsample_parameters[i][3], + ) + ) + + for _ in range(self.num_res_layers): + blocks.append( + VQVAEResidualUnit( + spatial_dims=self.spatial_dims, + in_channels=self.channels[i], + num_res_channels=self.num_res_channels[i], + act=self.act, + dropout=self.dropout, + ) + ) + + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.channels[len(self.channels) - 1], + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class Decoder(nn.Module): + """ + Decoder module for VQ-VAE. + + Args: + spatial_dims: number of spatial spatial_dims. + in_channels: number of channels in the latent space (embedding_dim). + out_channels: number of output channels. + channels: sequence containing the number of channels at each level of the decoder. + num_res_layers: number of sequential residual layers at each level. + num_res_channels: number of channels in the residual layers at each level. + upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). + dropout: dropout ratio. + act: activation type and arguments. + output_act: activation type and arguments for the output. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int], + num_res_layers: int, + num_res_channels: Sequence[int], + upsample_parameters: Sequence[Tuple[int, int, int, int, int]], + dropout: float, + act: tuple | str | None, + output_act: tuple | str | None, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + self.upsample_parameters = upsample_parameters + self.dropout = dropout + self.act = act + self.output_act = output_act + + reversed_num_channels = list(reversed(self.channels)) + + blocks: list[nn.Module] = [] + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels, + out_channels=reversed_num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + reversed_num_res_channels = list(reversed(self.num_res_channels)) + for i in range(len(self.channels)): + for _ in range(self.num_res_layers): + blocks.append( + VQVAEResidualUnit( + spatial_dims=self.spatial_dims, + in_channels=reversed_num_channels[i], + num_res_channels=reversed_num_res_channels[i], + act=self.act, + dropout=self.dropout, + ) + ) + + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=reversed_num_channels[i], + out_channels=self.out_channels if i == len(self.channels) - 1 else reversed_num_channels[i + 1], + strides=self.upsample_parameters[i][0], + kernel_size=self.upsample_parameters[i][1], + adn_ordering="DA", + act=self.act, + dropout=self.dropout if i != len(self.channels) - 1 else None, + norm=None, + dilation=self.upsample_parameters[i][2], + conv_only=i == len(self.channels) - 1, + is_transposed=True, + padding=self.upsample_parameters[i][3], + output_padding=self.upsample_parameters[i][4], + ) + ) + + if self.output_act: + blocks.append(Act[self.output_act]()) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class VQVAE(nn.Module): + """ + Vector-Quantised Variational Autoencoder (VQ-VAE) used in Morphology-preserving Autoregressive 3D Generative + Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf) + + The original implementation can be found at + https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L163/ + + Args: + spatial_dims: number of spatial spatial_dims. + in_channels: number of input channels. + out_channels: number of output channels. + downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int) and padding (int). + upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). + num_res_layers: number of sequential residual layers at each level. + channels: number of channels at each level. + num_res_channels: number of channels in the residual layers at each level. + num_embeddings: VectorQuantization number of atomic elements in the codebook. + embedding_dim: VectorQuantization number of channels of the input and atomic elements. + commitment_cost: VectorQuantization commitment_cost. + decay: VectorQuantization decay. + epsilon: VectorQuantization epsilon. + act: activation type and arguments. + dropout: dropout ratio. + output_act: activation type and arguments for the output. + ddp_sync: whether to synchronize the codebook across processes. + use_checkpointing if True, use activation checkpointing to save memory. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int] = (96, 96, 192), + num_res_layers: int = 3, + num_res_channels: Sequence[int] | int = (96, 96, 192), + downsample_parameters: Sequence[Tuple[int, int, int, int]] + | Tuple[int, int, int, int] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), + upsample_parameters: Sequence[Tuple[int, int, int, int, int]] + | Tuple[int, int, int, int, int] = ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + num_embeddings: int = 32, + embedding_dim: int = 64, + embedding_init: str = "normal", + commitment_cost: float = 0.25, + decay: float = 0.5, + epsilon: float = 1e-5, + dropout: float = 0.0, + act: tuple | str | None = Act.RELU, + output_act: tuple | str | None = None, + ddp_sync: bool = True, + use_checkpointing: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.spatial_dims = spatial_dims + self.channels = channels + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.use_checkpointing = use_checkpointing + + if isinstance(num_res_channels, int): + num_res_channels = ensure_tuple_rep(num_res_channels, len(channels)) + + if len(num_res_channels) != len(channels): + raise ValueError( + "`num_res_channels` should be a single integer or a tuple of integers with the same length as " + "`num_channls`." + ) + if all(isinstance(values, int) for values in upsample_parameters): + upsample_parameters_tuple: Sequence = (upsample_parameters,) * len(channels) + else: + upsample_parameters_tuple = upsample_parameters + + if all(isinstance(values, int) for values in downsample_parameters): + downsample_parameters_tuple: Sequence = (downsample_parameters,) * len(channels) + else: + downsample_parameters_tuple = downsample_parameters + + if not all(all(isinstance(value, int) for value in sub_item) for sub_item in downsample_parameters_tuple): + raise ValueError("`downsample_parameters` should be a single tuple of integer or a tuple of tuples.") + + # check if downsample_parameters is a tuple of ints or a tuple of tuples of ints + if not all(all(isinstance(value, int) for value in sub_item) for sub_item in upsample_parameters_tuple): + raise ValueError("`upsample_parameters` should be a single tuple of integer or a tuple of tuples.") + + for parameter in downsample_parameters_tuple: + if len(parameter) != 4: + raise ValueError("`downsample_parameters` should be a tuple of tuples with 4 integers.") + + for parameter in upsample_parameters_tuple: + if len(parameter) != 5: + raise ValueError("`upsample_parameters` should be a tuple of tuples with 5 integers.") + + if len(downsample_parameters_tuple) != len(channels): + raise ValueError( + "`downsample_parameters` should be a tuple of tuples with the same length as `num_channels`." + ) + + if len(upsample_parameters_tuple) != len(channels): + raise ValueError( + "`upsample_parameters` should be a tuple of tuples with the same length as `num_channels`." + ) + + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=embedding_dim, + channels=channels, + num_res_layers=num_res_layers, + num_res_channels=num_res_channels, + downsample_parameters=downsample_parameters_tuple, + dropout=dropout, + act=act, + ) + + self.decoder = Decoder( + spatial_dims=spatial_dims, + in_channels=embedding_dim, + out_channels=out_channels, + channels=channels, + num_res_layers=num_res_layers, + num_res_channels=num_res_channels, + upsample_parameters=upsample_parameters_tuple, + dropout=dropout, + act=act, + output_act=output_act, + ) + + self.quantizer = VectorQuantizer( + quantizer=EMAQuantizer( + spatial_dims=spatial_dims, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + commitment_cost=commitment_cost, + decay=decay, + epsilon=epsilon, + embedding_init=embedding_init, + ddp_sync=ddp_sync, + ) + ) + + def encode(self, images: torch.Tensor) -> torch.Tensor: + output: torch.Tensor + if self.use_checkpointing: + output = torch.utils.checkpoint.checkpoint(self.encoder, images, use_reentrant=False) + else: + output = self.encoder(images) + return output + + def quantize(self, encodings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + x_loss, x = self.quantizer(encodings) + return x, x_loss + + def decode(self, quantizations: torch.Tensor) -> torch.Tensor: + output: torch.Tensor + + if self.use_checkpointing: + output = torch.utils.checkpoint.checkpoint(self.decoder, quantizations, use_reentrant=False) + else: + output = self.decoder(quantizations) + return output + + def index_quantize(self, images: torch.Tensor) -> torch.Tensor: + return self.quantizer.quantize(self.encode(images=images)) + + def decode_samples(self, embedding_indices: torch.Tensor) -> torch.Tensor: + return self.decode(self.quantizer.embed(embedding_indices)) + + def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + quantizations, quantization_losses = self.quantize(self.encode(images)) + reconstruction = self.decode(quantizations) + + return reconstruction, quantization_losses + + def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + z = self.encode(x) + e, _ = self.quantize(z) + return e + + def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: + e, _ = self.quantize(z) + image = self.decode(e) + return image diff --git a/tests/test_vector_quantizer.py b/tests/test_vector_quantizer.py new file mode 100644 index 0000000000..43533d0377 --- /dev/null +++ b/tests/test_vector_quantizer.py @@ -0,0 +1,89 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from math import prod + +import torch +from parameterized import parameterized + +from monai.networks.layers import EMAQuantizer, VectorQuantizer + +TEST_CASES = [ + [{"spatial_dims": 2, "num_embeddings": 16, "embedding_dim": 8}, (1, 8, 4, 4), (1, 4, 4)], + [{"spatial_dims": 3, "num_embeddings": 16, "embedding_dim": 8}, (1, 8, 4, 4, 4), (1, 4, 4, 4)], +] + + +class TestEMA(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_ema_shape(self, input_param, input_shape, output_shape): + layer = EMAQuantizer(**input_param) + x = torch.randn(input_shape) + layer = layer.train() + outputs = layer(x) + self.assertEqual(outputs[0].shape, input_shape) + self.assertEqual(outputs[2].shape, output_shape) + + layer = layer.eval() + outputs = layer(x) + self.assertEqual(outputs[0].shape, input_shape) + self.assertEqual(outputs[2].shape, output_shape) + + @parameterized.expand(TEST_CASES) + def test_ema_quantize(self, input_param, input_shape, output_shape): + layer = EMAQuantizer(**input_param) + x = torch.randn(input_shape) + outputs = layer.quantize(x) + self.assertEqual(outputs[0].shape, (prod(input_shape[2:]), input_shape[1])) # (HxW[xD], C) + self.assertEqual(outputs[1].shape, (prod(input_shape[2:]), input_param["num_embeddings"])) # (HxW[xD], E) + self.assertEqual(outputs[2].shape, (input_shape[0],) + input_shape[2:]) # (1, H, W, [D]) + + def test_ema(self): + layer = EMAQuantizer(spatial_dims=2, num_embeddings=2, embedding_dim=2, epsilon=0, decay=0) + original_weight_0 = layer.embedding.weight[0].clone() + original_weight_1 = layer.embedding.weight[1].clone() + x_0 = original_weight_0 + x_0 = x_0.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + x_0 = x_0.repeat(1, 1, 1, 2) + 0.001 + + x_1 = original_weight_1 + x_1 = x_1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + x_1 = x_1.repeat(1, 1, 1, 2) + + x = torch.cat([x_0, x_1], dim=0) + layer = layer.train() + _ = layer(x) + + self.assertTrue(all(layer.embedding.weight[0] != original_weight_0)) + self.assertTrue(all(layer.embedding.weight[1] == original_weight_1)) + + +class TestVectorQuantizer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_vector_quantizer_shape(self, input_param, input_shape, output_shape): + layer = VectorQuantizer(EMAQuantizer(**input_param)) + x = torch.randn(input_shape) + outputs = layer(x) + self.assertEqual(outputs[1].shape, input_shape) + + @parameterized.expand(TEST_CASES) + def test_vector_quantizer_quantize(self, input_param, input_shape, output_shape): + layer = VectorQuantizer(EMAQuantizer(**input_param)) + x = torch.randn(input_shape) + outputs = layer.quantize(x) + self.assertEqual(outputs.shape, output_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py new file mode 100644 index 0000000000..4916dc2faa --- /dev/null +++ b/tests/test_vqvae.py @@ -0,0 +1,274 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.vqvae import VQVAE +from tests.utils import SkipIfBeforePyTorchVersion + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8), + (1, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": 4, + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8, 8), + (1, 1, 8, 8, 8), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": (2, 4, 1, 1), + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8), + (1, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": (2, 4, 1, 1, 0), + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8, 8), + (1, 1, 8, 8, 8), + ], +] + +TEST_LATENT_SHAPE = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_res_layers": 1, + "channels": (8, 8), + "num_res_channels": (8, 8), + "num_embeddings": 16, + "embedding_dim": 8, +} + + +class TestVQVAE(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**input_param).to(device) + + with eval_mode(net): + result, _ = net(torch.randn(input_shape).to(device)) + + self.assertEqual(result.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_with_checkpoint(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + input_param = input_param.copy() + input_param.update({"use_checkpointing": True}) + + net = VQVAE(**input_param).to(device) + + with eval_mode(net): + result, _ = net(torch.randn(input_shape).to(device)) + + self.assertEqual(result.shape, expected_shape) + + # Removed this test case since TorchScript currently does not support activation checkpoint. + # def test_script(self): + # net = VQVAE( + # spatial_dims=2, + # in_channels=1, + # out_channels=1, + # downsample_parameters=((2, 4, 1, 1),) * 2, + # upsample_parameters=((2, 4, 1, 1, 0),) * 2, + # num_res_layers=1, + # channels=(8, 8), + # num_res_channels=(8, 8), + # num_embeddings=16, + # embedding_dim=8, + # ddp_sync=False, + # ) + # test_data = torch.randn(1, 1, 16, 16) + # test_script_save(net, test_data) + + def test_channels_not_same_size_of_num_res_channels(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_channels_not_same_size_of_downsample_parameters(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 3, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_channels_not_same_size_of_upsample_parameters(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 3, + ) + + def test_downsample_parameters_not_sequence_or_int(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=(("test", 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_upsample_parameters_not_sequence_or_int(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=(("test", 4, 1, 1, 0),) * 2, + ) + + def test_downsample_parameter_length_different_4(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1),) * 3, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_upsample_parameter_length_different_5(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0, 1),) * 3, + ) + + def test_encode_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.encode(torch.randn(1, 1, 32, 32).to(device)) + + self.assertEqual(latent.shape, (1, 8, 8, 8)) + + def test_index_quantize_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.index_quantize(torch.randn(1, 1, 32, 32).to(device)) + + self.assertEqual(latent.shape, (1, 8, 8)) + + def test_decode_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.decode(torch.randn(1, 8, 8, 8).to(device)) + + self.assertEqual(latent.shape, (1, 1, 32, 32)) + + def test_decode_samples_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.decode_samples(torch.randint(low=0, high=16, size=(1, 8, 8)).to(device)) + + self.assertEqual(latent.shape, (1, 1, 32, 32)) + + +if __name__ == "__main__": + unittest.main() From c61c6ac2d56af08fbbfb955324b8639e266a25db Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 11 Dec 2023 11:10:20 -0500 Subject: [PATCH 03/32] 6676 port generative networks transformer (#7300) Towards #6676 . ### Description Adds a simple decoder-only transformer architecture. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham --- docs/source/networks.rst | 5 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/transformer.py | 314 +++++++++++++++++++++++++++++ tests/test_transformer.py | 73 +++++++ 4 files changed, 393 insertions(+) create mode 100644 monai/networks/nets/transformer.py create mode 100644 tests/test_transformer.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index d8be26264b..06f60fe8af 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -613,6 +613,11 @@ Nets .. autoclass:: VarAutoEncoder :members: +`DecoderOnlyTransformer` +~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: DecoderOnlyTransformer + :members: + `ViT` ~~~~~ .. autoclass:: ViT diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index db3c77c717..08384b4d52 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -106,6 +106,7 @@ from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR from .torchvision_fc import TorchVisionFCModel from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex +from .transformer import DecoderOnlyTransformer from .unet import UNet, Unet from .unetr import UNETR from .varautoencoder import VarAutoEncoder diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py new file mode 100644 index 0000000000..b742c12205 --- /dev/null +++ b/monai/networks/nets/transformer.py @@ -0,0 +1,314 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks.mlp import MLPBlock +from monai.utils import optional_import + +xops, has_xformers = optional_import("xformers.ops") +__all__ = ["DecoderOnlyTransformer"] + + +class _SABlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + A self-attention block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + + Args: + hidden_size: dimension of hidden layer. + num_heads: number of attention heads. + dropout_rate: dropout ratio. Defaults to no dropout. + qkv_bias: bias term for the qkv linear layer. + causal: whether to use causal attention. + sequence_length: if causal is True, it is necessary to specify the sequence length. + with_cross_attention: Whether to use cross attention for conditioning. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + dropout_rate: float = 0.0, + qkv_bias: bool = False, + causal: bool = False, + sequence_length: int | None = None, + with_cross_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = 1.0 / math.sqrt(self.head_dim) + self.causal = causal + self.sequence_length = sequence_length + self.with_cross_attention = with_cross_attention + self.use_flash_attention = use_flash_attention + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + self.dropout_rate = dropout_rate + + if hidden_size % num_heads != 0: + raise ValueError("hidden size should be divisible by num_heads.") + + if causal and sequence_length is None: + raise ValueError("sequence_length is necessary for causal attention.") + + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is True but xformers is not installed.") + + # key, query, value projections + self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + self.to_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + + # regularization + self.drop_weights = nn.Dropout(dropout_rate) + self.drop_output = nn.Dropout(dropout_rate) + + # output projection + self.out_proj = nn.Linear(hidden_size, hidden_size) + + if causal and sequence_length is not None: + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "causal_mask", + torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), + ) + self.causal_mask: torch.Tensor + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + query = self.to_q(x) + + kv = context if context is not None else x + _, kv_t, _ = kv.size() + key = self.to_k(kv) + value = self.to_v(kv) + + query = query.view(b, t, self.num_heads, c // self.num_heads) # (b, t, nh, hs) + key = key.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) + value = value.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) + y: torch.Tensor + if self.use_flash_attention: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + y = xops.memory_efficient_attention( + query=query, + key=key, + value=value, + scale=self.scale, + p=self.dropout_rate, + attn_bias=xops.LowerTriangularMask() if self.causal else None, + ) + + else: + query = query.transpose(1, 2) # (b, nh, t, hs) + key = key.transpose(1, 2) # (b, nh, kv_t, hs) + value = value.transpose(1, 2) # (b, nh, kv_t, hs) + + # manual implementation of attention + query = query * self.scale + attention_scores = query @ key.transpose(-2, -1) + + if self.causal: + attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = self.drop_weights(attention_probs) + y = attention_probs @ value # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs) + + y = y.transpose(1, 2) # (b, nh, t, hs) -> (b, t, nh, hs) + + y = y.contiguous().view(b, t, c) # re-assemble all head outputs side by side + + y = self.out_proj(y) + y = self.drop_output(y) + return y + + +class _TransformerBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + A transformer block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + + Args: + hidden_size: dimension of hidden layer. + mlp_dim: dimension of feedforward layer. + num_heads: number of attention heads. + dropout_rate: faction of the input units to drop. + qkv_bias: apply bias term for the qkv linear layer + causal: whether to use causal attention. + sequence_length: if causal is True, it is necessary to specify the sequence length. + with_cross_attention: Whether to use cross attention for conditioning. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + hidden_size: int, + mlp_dim: int, + num_heads: int, + dropout_rate: float = 0.0, + qkv_bias: bool = False, + causal: bool = False, + sequence_length: int | None = None, + with_cross_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + self.with_cross_attention = with_cross_attention + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + raise ValueError("hidden_size should be divisible by num_heads.") + + self.norm1 = nn.LayerNorm(hidden_size) + self.attn = _SABlock( + hidden_size=hidden_size, + num_heads=num_heads, + dropout_rate=dropout_rate, + qkv_bias=qkv_bias, + causal=causal, + sequence_length=sequence_length, + use_flash_attention=use_flash_attention, + ) + + if self.with_cross_attention: + self.norm2 = nn.LayerNorm(hidden_size) + self.cross_attn = _SABlock( + hidden_size=hidden_size, + num_heads=num_heads, + dropout_rate=dropout_rate, + qkv_bias=qkv_bias, + with_cross_attention=with_cross_attention, + causal=False, + use_flash_attention=use_flash_attention, + ) + self.norm3 = nn.LayerNorm(hidden_size) + self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + x = x + self.attn(self.norm1(x)) + if self.with_cross_attention: + x = x + self.cross_attn(self.norm2(x), context=context) + x = x + self.mlp(self.norm3(x)) + return x + + +class AbsolutePositionalEmbedding(nn.Module): + """Absolute positional embedding. + + Args: + max_seq_len: Maximum sequence length. + embedding_dim: Dimensionality of the embedding. + """ + + def __init__(self, max_seq_len: int, embedding_dim: int) -> None: + super().__init__() + self.max_seq_len = max_seq_len + self.embedding_dim = embedding_dim + self.embedding = nn.Embedding(max_seq_len, embedding_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len = x.size() + positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1) + embedding: torch.Tensor = self.embedding(positions) + return embedding + + +class DecoderOnlyTransformer(nn.Module): + """Decoder-only (Autoregressive) Transformer model. + + Args: + num_tokens: Number of tokens in the vocabulary. + max_seq_len: Maximum sequence length. + attn_layers_dim: Dimensionality of the attention layers. + attn_layers_depth: Number of attention layers. + attn_layers_heads: Number of attention heads. + with_cross_attention: Whether to use cross attention for conditioning. + embedding_dropout_rate: Dropout rate for the embedding. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + num_tokens: int, + max_seq_len: int, + attn_layers_dim: int, + attn_layers_depth: int, + attn_layers_heads: int, + with_cross_attention: bool = False, + embedding_dropout_rate: float = 0.0, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.num_tokens = num_tokens + self.max_seq_len = max_seq_len + self.attn_layers_dim = attn_layers_dim + self.attn_layers_depth = attn_layers_depth + self.attn_layers_heads = attn_layers_heads + self.with_cross_attention = with_cross_attention + + self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim) + self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim) + self.embedding_dropout = nn.Dropout(embedding_dropout_rate) + + self.blocks = nn.ModuleList( + [ + _TransformerBlock( + hidden_size=attn_layers_dim, + mlp_dim=attn_layers_dim * 4, + num_heads=attn_layers_heads, + dropout_rate=0.0, + qkv_bias=False, + causal=True, + sequence_length=max_seq_len, + with_cross_attention=with_cross_attention, + use_flash_attention=use_flash_attention, + ) + for _ in range(attn_layers_depth) + ] + ) + + self.to_logits = nn.Linear(attn_layers_dim, num_tokens) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + tok_emb = self.token_embeddings(x) + pos_emb = self.position_embeddings(x) + x = self.embedding_dropout(tok_emb + pos_emb) + + for block in self.blocks: + x = block(x, context=context) + logits: torch.Tensor = self.to_logits(x) + return logits diff --git a/tests/test_transformer.py b/tests/test_transformer.py new file mode 100644 index 0000000000..ea6ebdf50f --- /dev/null +++ b/tests/test_transformer.py @@ -0,0 +1,73 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import DecoderOnlyTransformer + +TEST_CASES = [] +for dropout_rate in np.linspace(0, 1, 2): + for attention_layer_dim in [360, 480, 600, 768]: + for num_heads in [4, 6, 8, 12]: + TEST_CASES.append( + [ + { + "num_tokens": 10, + "max_seq_len": 16, + "attn_layers_dim": attention_layer_dim, + "attn_layers_depth": 2, + "attn_layers_heads": num_heads, + "embedding_dropout_rate": dropout_rate, + } + ] + ) + + +class TestDecoderOnlyTransformer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_unconditioned_models(self, input_param): + net = DecoderOnlyTransformer(**input_param) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16))) + + @parameterized.expand(TEST_CASES) + def test_conditioned_models(self, input_param): + net = DecoderOnlyTransformer(**input_param, with_cross_attention=True) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 3, input_param["attn_layers_dim"])) + + def test_attention_dim_not_multiple_of_heads(self): + with self.assertRaises(ValueError): + DecoderOnlyTransformer( + num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=3 + ) + + def test_dropout_rate_negative(self): + with self.assertRaises(ValueError): + DecoderOnlyTransformer( + num_tokens=10, + max_seq_len=16, + attn_layers_dim=8, + attn_layers_depth=2, + attn_layers_heads=2, + embedding_dropout_rate=-1, + ) + + +if __name__ == "__main__": + unittest.main() From de0a4760547eabe0337d9d9cf40fc90f6bb1cb59 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 12 Dec 2023 02:09:37 -0500 Subject: [PATCH 04/32] 6676 port generative networks ddpm (#7304) Towards #6676 . ### Description Adds a DDPM unet. Refactoring for some of the blocks here is scheduled [here](https://github.com/Project-MONAI/MONAI/issues/7227). ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/networks.rst | 5 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/diffusion_model_unet.py | 2138 +++++++++++++++++++ tests/test_diffusion_model_unet.py | 535 +++++ 4 files changed, 2679 insertions(+) create mode 100644 monai/networks/nets/diffusion_model_unet.py create mode 100644 tests/test_diffusion_model_unet.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 06f60fe8af..417fb8ac73 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -583,6 +583,11 @@ Nets .. autoclass:: VNet :members: +`DiffusionModelUnet` +~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: DiffusionModelUNet + :members: + `RegUNet` ~~~~~~~~~ .. autoclass:: RegUNet diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 08384b4d52..31fbd73b4e 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -35,6 +35,7 @@ densenet201, densenet264, ) +from .diffusion_model_unet import DiffusionModelUNet from .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch from .dynunet import DynUNet, DynUnet, Dynunet from .efficientnet import ( diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py new file mode 100644 index 0000000000..1532215c70 --- /dev/null +++ b/monai/networks/nets/diffusion_model_unet.py @@ -0,0 +1,2138 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import math +from collections.abc import Sequence + +import torch +import torch.nn.functional as F +from torch import nn + +from monai.networks.blocks import Convolution, MLPBlock +from monai.networks.layers.factories import Pool +from monai.utils import ensure_tuple_rep, optional_import + +# To install xformers, use pip install xformers==0.0.16rc401 + +xops, has_xformers = optional_import("xformers.ops") + + +__all__ = ["DiffusionModelUNet"] + + +def zero_module(module: nn.Module) -> nn.Module: + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class _CrossAttention(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + A cross attention layer. + + Args: + query_dim: number of channels in the query. + cross_attention_dim: number of channels in the context. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each head. + dropout: dropout probability to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: int | None = None, + num_attention_heads: int = 8, + num_head_channels: int = 64, + dropout: float = 0.0, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + inner_dim = num_head_channels * num_attention_heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + + self.scale = 1 / math.sqrt(num_head_channels) + self.num_heads = num_attention_heads + + self.upcast_attention = upcast_attention + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + """ + Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. + """ + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + """Combine the output of the attention heads back into the hidden state dimension.""" + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x: torch.Tensor = xops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(dtype=dtype) + + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + query = self.to_q(x) + context = context if context is not None else x + key = self.to_k(context) + value = self.to_v(context) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + output: torch.Tensor = self.to_out(x) + return output + + +class _BasicTransformerBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + A basic Transformer block. + + Args: + num_channels: number of channels in the input and output. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each attention head. + dropout: dropout probability to use. + cross_attention_dim: size of the context vector for cross attention. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + num_channels: int, + num_attention_heads: int, + num_head_channels: int, + dropout: float = 0.0, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.attn1 = _CrossAttention( + query_dim=num_channels, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) # is a self-attention + self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) + self.attn2 = _CrossAttention( + query_dim=num_channels, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) # is a self-attention if context is None + self.norm1 = nn.LayerNorm(num_channels) + self.norm2 = nn.LayerNorm(num_channels) + self.norm3 = nn.LayerNorm(num_channels) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + # 1. Self-Attention + x = self.attn1(self.norm1(x)) + x + + # 2. Cross-Attention + x = self.attn2(self.norm2(x), context=context) + x + + # 3. Feed-forward + x = self.ff(self.norm3(x)) + x + return x + + +class _SpatialTransformer(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply + standard transformer action. Finally, reshape to image. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of channels in the input and output. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each attention head. + num_layers: number of layers of Transformer blocks to use. + dropout: dropout probability to use. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_attention_heads: int, + num_head_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + inner_dim = num_attention_heads * num_head_channels + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + + self.proj_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=inner_dim, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + self.transformer_blocks = nn.ModuleList( + [ + _BasicTransformerBlock( + num_channels=inner_dim, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + for _ in range(num_layers) + ] + ) + + self.proj_out = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=inner_dim, + out_channels=in_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + ) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + # note: if no context is given, cross-attention defaults to self-attention + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + residual = x + x = self.norm(x) + x = self.proj_in(x) + + inner_dim = x.shape[1] + + if self.spatial_dims == 2: + x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + if self.spatial_dims == 3: + x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim) + + for block in self.transformer_blocks: + x = block(x, context=context) + + if self.spatial_dims == 2: + x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + if self.spatial_dims == 3: + x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous() + + x = self.proj_out(x) + return x + residual + + +class _AttentionBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to + compute attention. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + num_head_channels: number of channels in each attention head. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon value to use for the normalisation. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + self.spatial_dims = spatial_dims + self.num_channels = num_channels + + self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.scale = 1 / math.sqrt(num_channels / self.num_heads) + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + + self.to_q = nn.Linear(num_channels, num_channels) + self.to_k = nn.Linear(num_channels, num_channels) + self.to_v = nn.Linear(num_channels, num_channels) + + self.proj_attn = nn.Linear(num_channels, num_channels) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x: torch.Tensor = xops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + # norm + x = self.norm(x) + + if self.spatial_dims == 2: + x = x.view(batch, channel, height * width).transpose(1, 2) + if self.spatial_dims == 3: + x = x.view(batch, channel, height * width * depth).transpose(1, 2) + + # proj to q, k, v + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + + if self.spatial_dims == 2: + x = x.transpose(-1, -2).reshape(batch, channel, height, width) + if self.spatial_dims == 3: + x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) + + return x + residual + + +def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor: + """ + Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic + Models" https://arxiv.org/abs/2006.11239. + + Args: + timesteps: a 1-D Tensor of N indices, one per batch element. + embedding_dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + """ + if timesteps.ndim != 1: + raise ValueError("Timesteps should be a 1d-array") + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + freqs = torch.exp(exponent / half_dim) + + args = timesteps[:, None].float() * freqs[None, :] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0)) + + return embedding + + +class _Downsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Downsampling layer. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is + False, the number of output channels must be the same as the number of input channels. + out_channels: number of output channels. + padding: controls the amount of implicit zero-paddings on both sides for padding number of points + for each dimension. + """ + + def __init__( + self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 + ) -> None: + super().__init__() + self.num_channels = num_channels + self.out_channels = out_channels or num_channels + self.use_conv = use_conv + if use_conv: + self.op = Convolution( + spatial_dims=spatial_dims, + in_channels=self.num_channels, + out_channels=self.out_channels, + strides=2, + kernel_size=3, + padding=padding, + conv_only=True, + ) + else: + if self.num_channels != self.out_channels: + raise ValueError("num_channels and out_channels must be equal when use_conv=False") + self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2) + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + del emb + if x.shape[1] != self.num_channels: + raise ValueError( + f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels " + f"({self.num_channels})" + ) + output: torch.Tensor = self.op(x) + return output + + +class _Upsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Upsampling layer with an optional convolution. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + use_conv: if True uses Convolution instead of Pool average to perform downsampling. + out_channels: number of output channels. + padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each + dimension. + """ + + def __init__( + self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 + ) -> None: + super().__init__() + self.num_channels = num_channels + self.out_channels = out_channels or num_channels + self.use_conv = use_conv + if use_conv: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=self.num_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=padding, + conv_only=True, + ) + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + del emb + if x.shape[1] != self.num_channels: + raise ValueError("Input channels should be equal to num_channels") + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + x = x.to(dtype) + + if self.use_conv: + x = self.conv(x) + return x + + +class _ResnetBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + Residual block with timestep conditioning. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + out_channels: number of output channels. + up: if True, performs upsampling. + down: if True, performs downsampling. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + out_channels: int | None = None, + up: bool = False, + down: bool = False, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = in_channels + self.emb_channels = temb_channels + self.out_channels = out_channels or in_channels + self.up = up + self.down = down + + self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + self.nonlinearity = nn.SiLU() + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.upsample = self.downsample = None + if self.up: + self.upsample = _Upsample(spatial_dims, in_channels, use_conv=False) + elif down: + self.downsample = _Downsample(spatial_dims, in_channels, use_conv=False) + + self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) + + self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True) + self.conv2 = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + self.skip_connection: nn.Module + if self.out_channels == in_channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = self.nonlinearity(h) + + if self.upsample is not None: + if h.shape[0] >= 64: + x = x.contiguous() + h = h.contiguous() + x = self.upsample(x) + h = self.upsample(h) + elif self.downsample is not None: + x = self.downsample(x) + h = self.downsample(h) + + h = self.conv1(h) + + if self.spatial_dims == 2: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] + else: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] + h = h + temb + + h = self.norm2(h) + h = self.nonlinearity(h) + h = self.conv2(h) + output: torch.Tensor = self.skip_connection(x) + h + return output + + +class DownBlock(nn.Module): + """ + Unet's down block containing resnet and downsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsampler: nn.Module | None + if resblock_updown: + self.downsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = _Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + del context + output_states = [] + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class AttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.downsampler: nn.Module | None + if add_downsample: + if resblock_updown: + self.downsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = _Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + del context + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class CrossAttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and cross-attention blocks. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + attentions.append( + _SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + num_layers=transformer_num_layers, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.downsampler: nn.Module | None + if add_downsample: + if resblock_updown: + self.downsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = _Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=context) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class AttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + + self.resnet_1 = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + self.attention = _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=in_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + + self.resnet_2 = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + del context + hidden_states = self.resnet_1(hidden_states, temb) + hidden_states = self.attention(hidden_states) + hidden_states = self.resnet_2(hidden_states, temb) + + return hidden_states + + +class CrossAttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and cross-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + + self.resnet_1 = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + self.attention = _SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_attention_heads=in_channels // num_head_channels, + num_head_channels=num_head_channels, + num_layers=transformer_num_layers, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + self.resnet_2 = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + hidden_states = self.resnet_1(hidden_states, temb) + hidden_states = self.attention(hidden_states, context=context) + hidden_states = self.resnet_2(hidden_states, temb) + + return hidden_states + + +class UpBlock(nn.Module): + """ + Unet's up block containing resnet and upsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class AttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class CrossAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + _SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=context) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +def get_down_block( + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_downsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_attn: + return AttnDownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + elif with_cross_attn: + return CrossAttnDownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return DownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + ) + + +def get_mid_block( + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int, + norm_eps: float, + with_conditioning: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_conditioning: + return CrossAttnMidBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return AttnMidBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + + +def get_up_block( + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_upsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_attn: + return AttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + elif with_cross_attn: + return CrossAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return UpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + ) + + +class DiffusionModelUNet(nn.Module): + """ + Unet network with timestep embedding and attention mechanisms for conditioning based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResnetBlock) per level. + channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + ) + if dropout_cattn > 1.0 or dropout_cattn < 0.0: + raise ValueError("Dropout cannot be negative or >1.0!") + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + + if len(channels) != len(attention_levels): + raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is True but xformers is not installed.") + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.in_channels = in_channels + self.block_out_channels = channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # down + self.down_blocks = nn.ModuleList([]) + output_channel = channels[0] + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + self.down_blocks.append(down_block) + + # mid + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=channels[-1], + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(channels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_head_channels = list(reversed(num_head_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)] + + is_final_block = i == len(channels) - 1 + + up_block = get_up_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=reversed_num_res_blocks[i] + 1, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(reversed_attention_levels[i] and not with_conditioning), + with_cross_attn=(reversed_attention_levels[i] and with_conditioning), + num_head_channels=reversed_num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + self.up_blocks.append(up_block) + + # out + self.out = nn.Sequential( + nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True), + nn.SiLU(), + zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=channels[0], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ), + ) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). + mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + # Additional residual conections for Controlnets + if down_block_additional_residuals is not None: + new_down_block_res_samples: list[torch.Tensor] = [] + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += [down_block_res_sample] + + down_block_res_samples = new_down_block_res_samples + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # Additional residual conections for Controlnets + if mid_block_additional_residual is not None: + h = h + mid_block_additional_residual + + # 6. up + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) + + # 7. output block + output: torch.Tensor = self.out(h) + + return output + + +class DiffusionModelEncoder(nn.Module): + """ + Classification Network based on the Encoder of the Diffusion Model, followed by fully connected layers. This network is based on + Wolleb et al. "Diffusion Models for Medical Anomaly Detection" (https://arxiv.org/abs/2203.04306). + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResnetBlock) per level. + channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. + upcast_attention: if True, upcast attention operations to full precision. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim." + ) + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("DiffusionModelEncoder expects all num_channels being multiple of norm_num_groups") + if len(channels) != len(attention_levels): + raise ValueError("DiffusionModelEncoder expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + self.in_channels = in_channels + self.block_out_channels = channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # down + self.down_blocks = nn.ModuleList([]) + output_channel = channels[0] + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) # - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + + self.down_blocks.append(down_block) + + self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + for downsample_block in self.down_blocks: + h, _ = downsample_block(hidden_states=h, temb=emb, context=context) + + h = h.reshape(h.shape[0], -1) + output: torch.Tensor = self.out(h) + + return output diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py new file mode 100644 index 0000000000..d40a31a1da --- /dev/null +++ b/tests/test_diffusion_model_unet.py @@ -0,0 +1,535 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import DiffusionModelUNet + +UNCOND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": (1, 1, 2), + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, True, True), + "num_head_channels": (0, 2, 4), + "norm_num_groups": 8, + } + ], +] + +UNCOND_CASES_3D = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": (0, 0, 4), + "norm_num_groups": 8, + } + ], +] + +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + } + ], +] + +DROPOUT_OK = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "dropout_cattn": 0.25, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + } + ], +] + +DROPOUT_WRONG = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "dropout_cattn": 3.0, + } + ] +] + + +class TestDiffusionModelUNet2D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_2D) + def test_shape_unconditioned_models(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + def test_timestep_with_wrong_shape(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long()) + + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with eval_mode(net): + result = net.forward(torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, out_channels, 16, 16)) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 12), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_attention_levels_with_different_length_num_head_channels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, False), + num_head_channels=(0, 2), + norm_num_groups=8, + ) + + def test_num_res_blocks_with_different_length_channels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=(1, 1), + channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_shape_conditioned_models(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + norm_num_groups=8, + num_head_channels=8, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_with_conditioning_cross_attention_dim_none(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=None, + norm_num_groups=8, + ) + + def test_context_with_conditioning_none(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=False, + transformer_num_layers=1, + norm_num_groups=8, + ) + + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + + def test_shape_conditioned_models_class_conditioning(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + class_labels=torch.randint(0, 2, (1,)).long(), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_conditioned_models_no_class_labels(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + with self.assertRaises(ValueError): + net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long()) + + def test_model_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + @parameterized.expand(COND_CASES_2D) + def test_conditioned_2d_models_shape(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + +class TestDiffusionModelUNet3D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_3D) + def test_shape_unconditioned_models(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = DiffusionModelUNet( + spatial_dims=3, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=4, + ) + with eval_mode(net): + result = net.forward(torch.rand((1, in_channels, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) + + def test_shape_conditioned_models(self): + net = DiffusionModelUNet( + spatial_dims=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(16, 16, 16), + attention_levels=(False, False, True), + norm_num_groups=16, + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 16, 16)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + # Test dropout specification for cross-attention blocks + @parameterized.expand(DROPOUT_WRONG) + def test_wrong_dropout(self, input_param): + with self.assertRaises(ValueError): + _ = DiffusionModelUNet(**input_param) + + @parameterized.expand(DROPOUT_OK) + def test_right_dropout(self, input_param): + _ = DiffusionModelUNet(**input_param) + + +if __name__ == "__main__": + unittest.main() From 43bc0230f8dd042924ccb6267317622fdffc695e Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 13 Dec 2023 22:37:51 -0500 Subject: [PATCH 05/32] 6676 port generative networks controlnet (#7312) Part of #6676 . ### Description Ports the ControlNet. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham --- docs/source/networks.rst | 5 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/controlnet.py | 421 ++++++++++++++++++++++++++++++ tests/test_controlnet.py | 177 +++++++++++++ 4 files changed, 604 insertions(+) create mode 100644 monai/networks/nets/controlnet.py create mode 100644 tests/test_controlnet.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 417fb8ac73..0960fcdbc0 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -588,6 +588,11 @@ Nets .. autoclass:: DiffusionModelUNet :members: +`ControlNet` +~~~~~~~~~~~~ +.. autoclass:: ControlNet + :members: + `RegUNet` ~~~~~~~~~ .. autoclass:: RegUNet diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 31fbd73b4e..58cb652bae 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -18,6 +18,7 @@ from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus from .classifier import Classifier, Critic, Discriminator +from .controlnet import ControlNet from .daf3d import DAF3D from .densenet import ( DenseNet, diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py new file mode 100644 index 0000000000..d98755f401 --- /dev/null +++ b/monai/networks/nets/controlnet.py @@ -0,0 +1,421 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn.functional as F +from torch import nn + +from monai.networks.blocks import Convolution +from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding +from monai.utils import ensure_tuple_rep + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Network to encode the conditioning into a latent space. + """ + + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, channels: Sequence[int]): + super().__init__() + + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.blocks = nn.ModuleList([]) + + for i in range(len(channels) - 1): + channel_in = channels[i] + channel_out = channels[i + 1] + self.blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=channel_in, + out_channels=channel_in, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=channel_in, + out_channels=channel_out, + strides=2, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.conv_out = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=channels[-1], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNet(nn.Module): + """ + Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image + Diffusion Models" (https://arxiv.org/abs/2302.05543) + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + conditioning_embedding_in_channels: number of input channels for the conditioning embedding. + conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + conditioning_embedding_in_channels: int = 1, + conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256), + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "to be specified when with_conditioning=True." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + ) + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError( + f"DiffusionModelUNet expects all channels to be a multiple of norm_num_groups, but got" + f" channels={channels} and norm_num_groups={norm_num_groups}" + ) + + if len(channels) != len(attention_levels): + raise ValueError( + f"DiffusionModelUNet expects channels to have the same length as attention_levels, but got " + f"channels={channels} and attention_levels={attention_levels}" + ) + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + f"num_head_channels should have the same length as attention_levels, but got channels={channels} and " + f"attention_levels={attention_levels} . For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + f"`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + f"`num_channels`, but got num_res_blocks={num_res_blocks} and channels={channels}." + ) + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.in_channels = in_channels + self.block_out_channels = channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + spatial_dims=spatial_dims, + in_channels=conditioning_embedding_in_channels, + channels=conditioning_embedding_num_channels, + out_channels=channels[0], + ) + + # down + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + output_channel = channels[0] + + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block.conv) + self.controlnet_down_blocks.append(controlnet_block) + + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + self.down_blocks.append(down_block) + + for _ in range(num_res_blocks[i]): + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + # + if not is_final_block: + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = channels[-1] + + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + ) -> tuple[list[torch.Tensor], torch.Tensor]: + """ + Args: + x: input tensor (N, C, H, W, [D]). + timesteps: timestep tensor (N,). + controlnet_cond: controlnet conditioning tensor (N, C, H, W, [D]) + conditioning_scale: conditioning scale. + context: context tensor (N, 1, cross_attention_dim), where cross_attention_dim is specified in the model init. + class_labels: context tensor (N, ). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + + h += controlnet_cond + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # 6. Control net blocks + controlnet_down_block_res_samples = [] + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples.append(down_block_res_sample) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample: torch.Tensor = self.controlnet_mid_block(h) + + # 6. scaling + down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples] + mid_block_res_sample *= conditioning_scale + + return down_block_res_samples, mid_block_res_sample diff --git a/tests/test_controlnet.py b/tests/test_controlnet.py new file mode 100644 index 0000000000..07dfa2e49b --- /dev/null +++ b/tests/test_controlnet.py @@ -0,0 +1,177 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.controlnet import ControlNet + +UNCOND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (4, 4, 4), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 4, + }, + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + }, + (1, 8, 4, 4), + ], +] + +UNCOND_CASES_3D = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + }, + (1, 8, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (4, 4, 4), + "num_head_channels": 4, + "attention_levels": (False, False, False), + "norm_num_groups": 4, + "resblock_updown": True, + }, + (1, 4, 4, 4, 4), + ], +] + +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + }, + (1, 8, 4, 4), + ], +] + + +class TestControlNet(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_2D + UNCOND_CASES_3D) + def test_shape_unconditioned_models(self, input_param, expected_output_shape): + input_param["conditioning_embedding_in_channels"] = input_param["in_channels"] + input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],) + net = ControlNet(**input_param) + with eval_mode(net): + x = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + timesteps = torch.randint(0, 1000, (1,)).long() + controlnet_cond = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + result = net.forward(x, timesteps=timesteps, controlnet_cond=controlnet_cond) + self.assertEqual(len(result[0]), 2 * len(input_param["channels"])) + self.assertEqual(result[1].shape, expected_output_shape) + + @parameterized.expand(COND_CASES_2D) + def test_shape_conditioned_models(self, input_param, expected_output_shape): + input_param["conditioning_embedding_in_channels"] = input_param["in_channels"] + input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],) + net = ControlNet(**input_param) + with eval_mode(net): + x = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + timesteps = torch.randint(0, 1000, (1,)).long() + controlnet_cond = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + result = net.forward(x, timesteps=timesteps, controlnet_cond=controlnet_cond, context=torch.rand((1, 1, 3))) + self.assertEqual(len(result[0]), 2 * len(input_param["channels"])) + self.assertEqual(result[1].shape, expected_output_shape) + + +if __name__ == "__main__": + unittest.main() From b85a534a32a9d969aba7b6ba752d1f2486c42177 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 14 Dec 2023 06:06:59 -0500 Subject: [PATCH 06/32] Adds patchgan discriminator (#7319) Part of #6676 . ### Description Adds a patchgan-style discriminator, both single scale and multiscale. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/networks.rst | 8 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/patchgan_discriminator.py | 247 ++++++++++++++++++ tests/test_patch_gan_dicriminator.py | 179 +++++++++++++ 4 files changed, 435 insertions(+) create mode 100644 monai/networks/nets/patchgan_discriminator.py create mode 100644 tests/test_patch_gan_dicriminator.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 0960fcdbc0..8e79298941 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -756,6 +756,14 @@ Nets .. autoclass:: VQVAE :members: +`PatchGANDiscriminator` +~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: PatchDiscriminator + :members: + +.. autoclass:: MultiScalePatchDiscriminator + :members: + Utilities --------- .. automodule:: monai.networks.utils diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 58cb652bae..0f0d033d63 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -55,6 +55,7 @@ from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet from .milmodel import MILModel from .netadapter import NetAdapter +from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator from .quicknat import Quicknat from .regressor import Regressor from .regunet import GlobalNet, LocalNet, RegUNet diff --git a/monai/networks/nets/patchgan_discriminator.py b/monai/networks/nets/patchgan_discriminator.py new file mode 100644 index 0000000000..3b089616ce --- /dev/null +++ b/monai/networks/nets/patchgan_discriminator.py @@ -0,0 +1,247 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn as nn + +from monai.networks.blocks import Convolution +from monai.networks.layers import Act + + +class MultiScalePatchDiscriminator(nn.Sequential): + """ + Multi-scale Patch-GAN discriminator based on Pix2PixHD: + High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585) + + The Multi-scale discriminator made up of several PatchGAN discriminators, that process the images + at different spatial scales. + + Args: + num_d: number of discriminators + num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the first + discriminator. Each subsequent discriminator has one additional layer, meaning the output size is halved. + spatial_dims: number of spatial dimensions (1D, 2D etc.) + channels: number of filters in the first convolutional layer (doubled for each subsequent layer) + in_channels: number of input channels + out_channels: number of output channels in each discriminator + kernel_size: kernel size of the convolution layers + activation: activation layer type + norm: normalisation type + bias: introduction of layer bias + dropout: probability of dropout applied, defaults to 0. + minimum_size_im: minimum spatial size of the input image. Introduced to make sure the architecture + requested isn't going to downsample the input image beyond value of 1. + last_conv_kernel_size: kernel size of the last convolutional layer. + """ + + def __init__( + self, + num_d: int, + num_layers_d: int, + spatial_dims: int, + channels: int, + in_channels: int, + out_channels: int = 1, + kernel_size: int = 4, + activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + norm: str | tuple = "BATCH", + bias: bool = False, + dropout: float | tuple = 0.0, + minimum_size_im: int = 256, + last_conv_kernel_size: int = 1, + ) -> None: + super().__init__() + self.num_d = num_d + self.num_layers_d = num_layers_d + self.num_channels = channels + self.padding = tuple([int((kernel_size - 1) / 2)] * spatial_dims) + for i_ in range(self.num_d): + num_layers_d_i = self.num_layers_d * (i_ + 1) + output_size = float(minimum_size_im) / (2**num_layers_d_i) + if output_size < 1: + raise AssertionError( + f"Your image size is too small to take in up to {i_} discriminators with num_layers = {num_layers_d_i}." + "Please reduce num_layers, reduce num_D or enter bigger images." + ) + subnet_d = PatchDiscriminator( + spatial_dims=spatial_dims, + channels=self.num_channels, + in_channels=in_channels, + out_channels=out_channels, + num_layers_d=num_layers_d_i, + kernel_size=kernel_size, + activation=activation, + norm=norm, + bias=bias, + padding=self.padding, + dropout=dropout, + last_conv_kernel_size=last_conv_kernel_size, + ) + + self.add_module("discriminator_%d" % i_, subnet_d) + + def forward(self, i: torch.Tensor) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]: + """ + Args: + i: Input tensor + + Returns: + list of outputs and another list of lists with the intermediate features + of each discriminator. + """ + + out: list[torch.Tensor] = [] + intermediate_features: list[list[torch.Tensor]] = [] + for disc in self.children(): + out_d: list[torch.Tensor] = disc(i) + out.append(out_d[-1]) + intermediate_features.append(out_d[:-1]) + + return out, intermediate_features + + +class PatchDiscriminator(nn.Sequential): + """ + Patch-GAN discriminator based on Pix2PixHD: + High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585) + + + Args: + spatial_dims: number of spatial dimensions (1D, 2D etc.) + channels: number of filters in the first convolutional layer (doubled for each subsequent layer) + in_channels: number of input channels + out_channels: number of output channels + num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the discriminator. + kernel_size: kernel size of the convolution layers + act: activation type and arguments. Defaults to LeakyReLU. + norm: feature normalization type and arguments. Defaults to batch norm. + bias: whether to have a bias term in convolution blocks. Defaults to False. + padding: padding to be applied to the convolutional layers + dropout: proportion of dropout applied, defaults to 0. + last_conv_kernel_size: kernel size of the last convolutional layer. + """ + + def __init__( + self, + spatial_dims: int, + channels: int, + in_channels: int, + out_channels: int = 1, + num_layers_d: int = 3, + kernel_size: int = 4, + activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + norm: str | tuple = "BATCH", + bias: bool = False, + padding: int | Sequence[int] = 1, + dropout: float | tuple = 0.0, + last_conv_kernel_size: int | None = None, + ) -> None: + super().__init__() + self.num_layers_d = num_layers_d + self.num_channels = channels + if last_conv_kernel_size is None: + last_conv_kernel_size = kernel_size + + self.add_module( + "initial_conv", + Convolution( + spatial_dims=spatial_dims, + kernel_size=kernel_size, + in_channels=in_channels, + out_channels=channels, + act=activation, + bias=True, + norm=None, + dropout=dropout, + padding=padding, + strides=2, + ), + ) + + input_channels = channels + output_channels = channels * 2 + + # Initial Layer + for l_ in range(self.num_layers_d): + if l_ == self.num_layers_d - 1: + stride = 1 + else: + stride = 2 + layer = Convolution( + spatial_dims=spatial_dims, + kernel_size=kernel_size, + in_channels=input_channels, + out_channels=output_channels, + act=activation, + bias=bias, + norm=norm, + dropout=dropout, + padding=padding, + strides=stride, + ) + self.add_module("%d" % l_, layer) + input_channels = output_channels + output_channels = output_channels * 2 + + # Final layer + self.add_module( + "final_conv", + Convolution( + spatial_dims=spatial_dims, + kernel_size=last_conv_kernel_size, + in_channels=input_channels, + out_channels=out_channels, + bias=True, + conv_only=True, + padding=int((last_conv_kernel_size - 1) / 2), + dropout=0.0, + strides=1, + ), + ) + + self.apply(self.initialise_weights) + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + """ + Args: + x: input tensor + + Returns: + list of intermediate features, with the last element being the output. + """ + out = [x] + for submodel in self.children(): + intermediate_output = submodel(out[-1]) + out.append(intermediate_output) + + return out[1:] + + def initialise_weights(self, m: nn.Module) -> None: + """ + Initialise weights of Convolution and BatchNorm layers. + + Args: + m: instance of torch.nn.module (or of class inheriting torch.nn.module) + """ + classname = m.__class__.__name__ + if classname.find("Conv2d") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("Conv3d") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("Conv1d") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) diff --git a/tests/test_patch_gan_dicriminator.py b/tests/test_patch_gan_dicriminator.py new file mode 100644 index 0000000000..c19898e70d --- /dev/null +++ b/tests/test_patch_gan_dicriminator.py @@ -0,0 +1,179 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import MultiScalePatchDiscriminator, PatchDiscriminator +from tests.utils import test_script_save + +TEST_PATCHGAN = [ + [ + { + "num_layers_d": 3, + "spatial_dims": 2, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + }, + torch.rand([1, 3, 256, 512]), + (1, 8, 128, 256), + (1, 1, 32, 64), + ], + [ + { + "num_layers_d": 3, + "spatial_dims": 3, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + }, + torch.rand([1, 3, 256, 512, 256]), + (1, 8, 128, 256, 128), + (1, 1, 32, 64, 32), + ], +] + +TEST_MULTISCALE_PATCHGAN = [ + [ + { + "num_d": 2, + "num_layers_d": 3, + "spatial_dims": 2, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + }, + torch.rand([1, 3, 256, 512]), + [(1, 1, 32, 64), (1, 1, 4, 8)], + [4, 7], + ], + [ + { + "num_d": 2, + "num_layers_d": 3, + "spatial_dims": 3, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + }, + torch.rand([1, 3, 256, 512, 256]), + [(1, 1, 32, 64, 32), (1, 1, 4, 8, 4)], + [4, 7], + ], +] +TEST_TOO_SMALL_SIZE = [ + { + "num_d": 2, + "num_layers_d": 6, + "spatial_dims": 2, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + } +] + + +class TestPatchGAN(unittest.TestCase): + @parameterized.expand(TEST_PATCHGAN) + def test_shape(self, input_param, input_data, expected_shape_feature, expected_shape_output): + net = PatchDiscriminator(**input_param) + with eval_mode(net): + result = net.forward(input_data) + self.assertEqual(tuple(result[0].shape), expected_shape_feature) + self.assertEqual(tuple(result[-1].shape), expected_shape_output) + + def test_script(self): + net = PatchDiscriminator( + num_layers_d=3, + spatial_dims=2, + channels=8, + in_channels=3, + out_channels=1, + kernel_size=3, + activation="LEAKYRELU", + norm="instance", + bias=False, + dropout=0.1, + ) + i = torch.rand([1, 3, 256, 512]) + test_script_save(net, i) + + +class TestMultiscalePatchGAN(unittest.TestCase): + @parameterized.expand(TEST_MULTISCALE_PATCHGAN) + def test_shape(self, input_param, input_data, expected_shape, features_lengths=None): + net = MultiScalePatchDiscriminator(**input_param) + with eval_mode(net): + result, features = net.forward(input_data) + for r_ind, r in enumerate(result): + self.assertEqual(tuple(r.shape), expected_shape[r_ind]) + for o_d_ind, o_d in enumerate(features): + self.assertEqual(len(o_d), features_lengths[o_d_ind]) + + def test_too_small_shape(self): + with self.assertRaises(AssertionError): + MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0]) + + def test_script(self): + net = MultiScalePatchDiscriminator( + num_d=2, + num_layers_d=3, + spatial_dims=2, + channels=8, + in_channels=3, + out_channels=1, + kernel_size=3, + activation="LEAKYRELU", + norm="instance", + bias=False, + dropout=0.1, + minimum_size_im=256, + ) + i = torch.rand([1, 3, 256, 512]) + test_script_save(net, i) + + +if __name__ == "__main__": + unittest.main() From aa4a4dbd3619dd443681c4688bd48ce1bea9b85d Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 19 Dec 2023 03:21:32 +0000 Subject: [PATCH 07/32] 6676 port generative networks spade (#7320) Towards #6676 . ### Description This adds SPADE-enabled autoencoder and diffusion_model_unet architectures. They are new implementations for each network, rather than options in the existing network, because @virginiafdez and I felt that adding additional options to the existing networks to enable spade compatibility significantly reduced the readability of them for users who were not interested in SPADE functionality. These are the last networks to be ported over. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Signed-off-by: Mark Graham Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/source/networks.rst | 14 + monai/networks/blocks/__init__.py | 1 + monai/networks/blocks/spade_norm.py | 96 ++ monai/networks/nets/__init__.py | 2 + monai/networks/nets/spade_autoencoderkl.py | 473 +++++++++ .../nets/spade_diffusion_model_unet.py | 908 ++++++++++++++++++ test_spade_autoencoderkl.py | 260 +++++ test_spade_diffusion_model_unet.py | 558 +++++++++++ 8 files changed, 2312 insertions(+) create mode 100644 monai/networks/blocks/spade_norm.py create mode 100644 monai/networks/nets/spade_autoencoderkl.py create mode 100644 monai/networks/nets/spade_diffusion_model_unet.py create mode 100644 test_spade_autoencoderkl.py create mode 100644 test_spade_diffusion_model_unet.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 8e79298941..79d5ef822e 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -258,6 +258,10 @@ N-Dim Fourier Transform .. autofunction:: monai.networks.blocks.fft_utils_t.fftshift .. autofunction:: monai.networks.blocks.fft_utils_t.ifftshift +`SPADE` +~~~~~~~ +.. autoclass:: monai.networks.blocks.spade_norm.SPADE + :members: Layers ------ @@ -588,6 +592,11 @@ Nets .. autoclass:: DiffusionModelUNet :members: +`SPADEDiffusionModelUNet` +~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: SPADEDiffusionModelUNet + :members: + `ControlNet` ~~~~~~~~~~~~ .. autoclass:: ControlNet @@ -618,6 +627,11 @@ Nets .. autoclass:: AutoencoderKL :members: +`SPADEAutoencoderKL` +~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: SPADEAutoencoderKL + :members: + `VarAutoEncoder` ~~~~~~~~~~~~~~~~ .. autoclass:: VarAutoEncoder diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index e67cb3376f..afb6664bd9 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -30,6 +30,7 @@ from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock from .segresnet_block import ResBlock from .selfattention import SABlock +from .spade_norm import SPADE from .squeeze_and_excitation import ( ChannelSELayer, ResidualSELayer, diff --git a/monai/networks/blocks/spade_norm.py b/monai/networks/blocks/spade_norm.py new file mode 100644 index 0000000000..b1046f3154 --- /dev/null +++ b/monai/networks/blocks/spade_norm.py @@ -0,0 +1,96 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import ADN, Convolution + + +class SPADE(nn.Module): + """ + Spatially Adaptive Normalization (SPADE) block, allowing for normalization of activations conditioned on a + semantic map. This block is used in SPADE-based image-to-image translation models, as described in + Semantic Image Synthesis with Spatially-Adaptive Normalization (https://arxiv.org/abs/1903.07291). + + Args: + label_nc: number of semantic labels + norm_nc: number of output channels + kernel_size: kernel size + spatial_dims: number of spatial dimensions + hidden_channels: number of channels in the intermediate gamma and beta layers + norm: type of base normalisation used before applying the SPADE normalisation + norm_params: parameters for the base normalisation + """ + + def __init__( + self, + label_nc: int, + norm_nc: int, + kernel_size: int = 3, + spatial_dims: int = 2, + hidden_channels: int = 64, + norm: str | tuple = "INSTANCE", + norm_params: dict | None = None, + ) -> None: + super().__init__() + + if norm_params is None: + norm_params = {} + if len(norm_params) != 0: + norm = (norm, norm_params) + self.param_free_norm = ADN( + act=None, dropout=0.0, norm=norm, norm_dim=spatial_dims, ordering="N", in_channels=norm_nc + ) + self.mlp_shared = Convolution( + spatial_dims=spatial_dims, + in_channels=label_nc, + out_channels=hidden_channels, + kernel_size=kernel_size, + norm=None, + act="LEAKYRELU", + ) + self.mlp_gamma = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_channels, + out_channels=norm_nc, + kernel_size=kernel_size, + act=None, + ) + self.mlp_beta = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_channels, + out_channels=norm_nc, + kernel_size=kernel_size, + act=None, + ) + + def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor: + """ + Args: + x: input tensor with shape (B, C, [spatial-dimensions]) where C is the number of semantic channels. + segmap: input segmentation map (B, C, [spatial-dimensions]) where C is the number of semantic channels. + The map will be interpolated to the dimension of x internally. + """ + + # Part 1. generate parameter-free normalized activations + normalized = self.param_free_norm(x) + + # Part 2. produce scaling and bias conditioned on semantic map + segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest") + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + out: torch.Tensor = normalized * (1 + gamma) + beta + return out diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 0f0d033d63..a7ce16ad64 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -106,6 +106,8 @@ seresnext50, seresnext101, ) +from .spade_autoencoderkl import SPADEAutoencoderKL +from .spade_diffusion_model_unet import SPADEDiffusionModelUNet from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR from .torchvision_fc import TorchVisionFCModel from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex diff --git a/monai/networks/nets/spade_autoencoderkl.py b/monai/networks/nets/spade_autoencoderkl.py new file mode 100644 index 0000000000..e064c19740 --- /dev/null +++ b/monai/networks/nets/spade_autoencoderkl.py @@ -0,0 +1,473 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution +from monai.networks.blocks.spade_norm import SPADE +from monai.networks.nets.autoencoderkl import Encoder, _AttentionBlock, _Upsample +from monai.utils import ensure_tuple_rep + +__all__ = ["SPADEAutoencoderKL"] + + +class SPADEResBlock(nn.Module): + """ + Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a + residual connection between input and output. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: input channels to the layer. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon for the normalisation. + out_channels: number of output channels. + label_nc: number of semantic channels for SPADE normalisation + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + norm_num_groups: int, + norm_eps: float, + out_channels: int, + label_nc: int, + spade_intermediate_channels: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.norm1 = SPADE( + label_nc=label_nc, + norm_nc=in_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "affine": False}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.norm2 = SPADE( + label_nc=label_nc, + norm_nc=out_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "affine": False}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + self.conv2 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.nin_shortcut: nn.Module + if self.in_channels != self.out_channels: + self.nin_shortcut = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + else: + self.nin_shortcut = nn.Identity() + + def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h, seg) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h, seg) + h = F.silu(h) + h = self.conv2(h) + + x = self.nin_shortcut(x) + + return x + h + + +class SPADEDecoder(nn.Module): + """ + Convolutional cascade upsampling from a spatial latent space into an image space. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + channels: sequence of block output channels. + in_channels: number of channels in the bottom layer (latent space) of the autoencoder. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from channels contain an attention block. + label_nc: number of semantic channels for SPADE normalisation. + with_nonlocal_attn: if True use non-local attention block. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + channels: Sequence[int], + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + label_nc: int, + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = channels + self.in_channels = in_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + self.label_nc = label_nc + + reversed_block_out_channels = list(reversed(channels)) + + blocks: list[nn.Module] = [] + + # Initial convolution + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=reversed_block_out_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + SPADEResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + SPADEResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + block_out_ch = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + block_in_ch = block_out_ch + block_out_ch = reversed_block_out_channels[i] + is_final_block = i == len(channels) - 1 + + for _ in range(reversed_num_res_blocks[i]): + blocks.append( + SPADEResBlock( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=block_out_ch, + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + block_in_ch = block_out_ch + + if reversed_attention_levels[i]: + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append(_Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=False)) + + blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + if isinstance(block, SPADEResBlock): + x = block(x, seg) + else: + x = block(x) + return x + + +class SPADEAutoencoderKL(nn.Module): + """ + Autoencoder model with KL-regularized latent space based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + label_nc: number of semantic channels for SPADE normalisation. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see ResBlock) per level. + channels: sequence of block output channels. + attention_levels: sequence of levels to add attention. + latent_channels: latent embedding dimension. + norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number. + norm_eps: epsilon for the normalization. + with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. + with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + label_nc: int, + in_channels: int = 1, + out_channels: int = 1, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + latent_channels: int = 3, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + with_encoder_nonlocal_attn: bool = True, + with_decoder_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("SPADEAutoencoderKL expects all channels being multiple of norm_num_groups") + + if len(channels) != len(attention_levels): + raise ValueError("SPADEAutoencoderKL expects channels being same size of attention_levels") + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`channels`." + ) + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=channels, + out_channels=latent_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_encoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + ) + self.decoder = SPADEDecoder( + spatial_dims=spatial_dims, + channels=channels, + in_channels=latent_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + label_nc=label_nc, + with_nonlocal_attn=with_decoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + spade_intermediate_channels=spade_intermediate_channels, + ) + self.quant_conv_mu = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.quant_conv_log_sigma = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.post_quant_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.latent_channels = latent_channels + + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations. + + Args: + x: BxCx[SPATIAL DIMS] tensor + + """ + h = self.encoder(x) + z_mu = self.quant_conv_mu(h) + z_log_var = self.quant_conv_log_sigma(h) + z_log_var = torch.clamp(z_log_var, -30.0, 20.0) + z_sigma = torch.exp(z_log_var / 2) + + return z_mu, z_sigma + + def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor: + """ + From the mean and sigma representations resulting of encoding an image through the latent space, + obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and + adding the mean. + + Args: + z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image + z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image + + Returns: + sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE] + """ + eps = torch.randn_like(z_sigma) + z_vae = z_mu + eps * z_sigma + return z_vae + + def reconstruct(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + """ + Encodes and decodes an input image. + + Args: + x: BxCx[SPATIAL DIMENSIONS] tensor. + seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm. + Returns: + reconstructed image, of the same shape as input + """ + z_mu, _ = self.encode(x) + reconstruction = self.decode(z_mu, seg) + return reconstruction + + def decode(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + """ + Based on a latent space sample, forwards it through the Decoder. + + Args: + z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE] + seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm. + Returns: + decoded image tensor + """ + z = self.post_quant_conv(z) + dec: torch.Tensor = self.decoder(z, seg) + return dec + + def forward(self, x: torch.Tensor, seg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + reconstruction = self.decode(z, seg) + return reconstruction, z_mu, z_sigma + + def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + return z + + def decode_stage_2_outputs(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + image = self.decode(z, seg) + return image diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py new file mode 100644 index 0000000000..d53327100e --- /dev/null +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -0,0 +1,908 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +from torch import nn + +from monai.networks.blocks import Convolution +from monai.networks.blocks.spade_norm import SPADE +from monai.networks.nets.diffusion_model_unet import ( + _AttentionBlock, + _Downsample, + _ResnetBlock, + _SpatialTransformer, + _Upsample, + get_down_block, + get_mid_block, + get_timestep_embedding, + zero_module, +) +from monai.utils import ensure_tuple_rep, optional_import + +# To install xformers, use pip install xformers==0.0.16rc401 +xops, has_xformers = optional_import("xformers.ops") + + +__all__ = ["SPADEDiffusionModelUNet"] + + +class SPADEResnetBlock(nn.Module): + """ + Residual block with timestep conditioning and SPADE norm. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation. + out_channels: number of output channels. + up: if True, performs upsampling. + down: if True, performs downsampling. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + label_nc: int, + out_channels: int | None = None, + up: bool = False, + down: bool = False, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = in_channels + self.emb_channels = temb_channels + self.out_channels = out_channels or in_channels + self.up = up + self.down = down + + self.norm1 = SPADE( + label_nc=label_nc, + norm_nc=in_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "eps": norm_eps, "affine": True}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + + self.nonlinearity = nn.SiLU() + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.upsample = self.downsample = None + if self.up: + self.upsample = _Upsample(spatial_dims, in_channels, use_conv=False) + elif down: + self.downsample = _Downsample(spatial_dims, in_channels, use_conv=False) + + self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) + + self.norm2 = SPADE( + label_nc=label_nc, + norm_nc=self.out_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "eps": norm_eps, "affine": True}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + self.conv2 = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + self.skip_connection: nn.Module + + if self.out_channels == in_channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor, emb: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h, seg) + h = self.nonlinearity(h) + + if self.upsample is not None: + if h.shape[0] >= 64: + x = x.contiguous() + h = h.contiguous() + x = self.upsample(x) + h = self.upsample(h) + elif self.downsample is not None: + x = self.downsample(x) + h = self.downsample(h) + + h = self.conv1(h) + + if self.spatial_dims == 2: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] + else: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] + h = h + temb + + h = self.norm2(h, seg) + h = self.nonlinearity(h) + h = self.conv2(h) + output: torch.Tensor = self.skip_connection(x) + h + return output + + +class SPADEUpBlock(nn.Module): + """ + Unet's up block containing resnet and upsamplers blocks. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + label_nc: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SPADEResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + seg: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb, seg) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class SPADEAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + label_nc: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SPADEResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + attentions.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + seg: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb, seg) + hidden_states = attn(hidden_states) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class SPADECrossAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + label_nc: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SPADEResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + attentions.append( + _SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + seg: torch.Tensor | None = None, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb, seg) + hidden_states = attn(hidden_states, context=context) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +def get_spade_up_block( + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_upsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + label_nc: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, +) -> nn.Module: + if with_attn: + return SPADEAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + spade_intermediate_channels=spade_intermediate_channels, + ) + elif with_cross_attn: + return SPADECrossAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + spade_intermediate_channels=spade_intermediate_channels, + ) + else: + return SPADEUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + spade_intermediate_channels=spade_intermediate_channels, + ) + + +class SPADEDiffusionModelUNet(nn.Module): + """ + UNet network with timestep embedding and attention mechanisms for conditioning, with added SPADE normalization for + semantic conditioning (Park et.al (2019): https://github.com/NVlabs/SPADE). An example tutorial can be found at + https://github.com/Project-MONAI/GenerativeModels/tree/main/tutorials/generative/2d_spade_ldm + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + label_nc: number of semantic channels for SPADE normalisation. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + num_channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + label_nc: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "SPADEDiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "SPADEDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + ) + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + raise ValueError("SPADEDiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + + if len(num_channels) != len(attention_levels): + raise ValueError("SPADEDiffusionModelUNet expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + + if len(num_res_blocks) != len(num_channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is True but xformers is not installed.") + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.in_channels = in_channels + self.block_out_channels = num_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + self.label_nc = label_nc + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = num_channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # down + self.down_blocks = nn.ModuleList([]) + output_channel = num_channels[0] + for i in range(len(num_channels)): + input_channel = output_channel + output_channel = num_channels[i] + is_final_block = i == len(num_channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + self.down_blocks.append(down_block) + + # mid + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(num_channels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_head_channels = list(reversed(num_head_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)] + + is_final_block = i == len(num_channels) - 1 + + up_block = get_spade_up_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=reversed_num_res_blocks[i] + 1, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(reversed_attention_levels[i] and not with_conditioning), + with_cross_attn=(reversed_attention_levels[i] and with_conditioning), + num_head_channels=reversed_num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + + self.up_blocks.append(up_block) + + # out + self.out = nn.Sequential( + nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True), + nn.SiLU(), + zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=num_channels[0], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ), + ) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + seg: torch.Tensor, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm. + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). + mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + # Additional residual conections for Controlnets + if down_block_additional_residuals is not None: + new_down_block_res_samples: list[torch.Tensor] = [h] + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples.append(down_block_res_sample) + + down_block_res_samples = new_down_block_res_samples + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # Additional residual conections for Controlnets + if mid_block_additional_residual is not None: + h = h + mid_block_additional_residual + + # 6. up + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, seg=seg, temb=emb, context=context) + + # 7. output block + output: torch.Tensor = self.out(h) + + return output diff --git a/test_spade_autoencoderkl.py b/test_spade_autoencoderkl.py new file mode 100644 index 0000000000..6675a6db67 --- /dev/null +++ b/test_spade_autoencoderkl.py @@ -0,0 +1,260 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import SPADEAutoencoderKL + +CASES = [ + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": (1, 1, 2), + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16, 16), + (1, 3, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + "spade_intermediate_channels": 32, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], +] + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class TestSPADEAutoEncoderKL(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, input_seg, expected_shape, expected_latent_shape): + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device), torch.randn(input_seg).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_channels_not_same_size_of_num_res_blocks(self): + with self.assertRaises(ValueError): + SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=(8, 8), + norm_num_groups=16, + ) + + def test_shape_encode(self): + input_param, input_shape, _, _, expected_latent_shape = CASES[0] + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_shape_sampling(self): + input_param, _, _, _, expected_latent_shape = CASES[0] + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + def test_shape_decode(self): + input_param, _, input_seg_shape, expected_input_shape, latent_shape = CASES[0] + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device), torch.randn(input_seg_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + def test_wrong_shape_decode(self): + net = SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + channels=(4, 4, 4), + latent_channels=4, + attention_levels=(False, False, False), + num_res_blocks=1, + norm_num_groups=4, + ) + with self.assertRaises(RuntimeError): + _ = net.decode(torch.randn((1, 1, 16, 16)).to(device), torch.randn((1, 6, 16, 16)).to(device)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test_spade_diffusion_model_unet.py b/test_spade_diffusion_model_unet.py new file mode 100644 index 0000000000..c8a2103cf6 --- /dev/null +++ b/test_spade_diffusion_model_unet.py @@ -0,0 +1,558 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import SPADEDiffusionModelUNet + +UNCOND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": (1, 1, 2), + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, True, True), + "num_head_channels": (0, 2, 4), + "norm_num_groups": 8, + "label_nc": 3, + } + ], +] + +UNCOND_CASES_3D = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + "spade_intermediate_channels": 256, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": (0, 0, 4), + "norm_num_groups": 8, + "label_nc": 3, + } + ], +] + +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + "label_nc": 3, + } + ], +] + + +class TestSPADEDiffusionModelUNet2D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_2D) + def test_shape_unconditioned_models(self, input_param): + net = SPADEDiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, input_param["label_nc"], 16, 16)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + def test_timestep_with_wrong_shape(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward( + torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long(), torch.rand((1, 3, 16, 16)) + ) + + def test_label_with_wrong_shape(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with self.assertRaises(RuntimeError): + with eval_mode(net): + net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 6, 16, 16))) + + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with eval_mode(net): + result = net.forward( + torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 3, 16, 16)) + ) + self.assertEqual(result.shape, (1, out_channels, 16, 16)) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 12), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_attention_levels_with_different_length_num_head_channels(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + num_head_channels=(0, 2), + norm_num_groups=8, + ) + + def test_num_res_blocks_with_different_length_num_channels(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=(1, 1), + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_shape_conditioned_models(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + norm_num_groups=8, + num_head_channels=8, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_with_conditioning_cross_attention_dim_none(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=None, + norm_num_groups=8, + ) + + def test_context_with_conditioning_none(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=False, + transformer_num_layers=1, + norm_num_groups=8, + ) + + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + context=torch.rand((1, 1, 3)), + ) + + def test_shape_conditioned_models_class_conditioning(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + class_labels=torch.randint(0, 2, (1,)).long(), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_conditioned_models_no_class_labels(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + with self.assertRaises(ValueError): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + ) + + def test_model_num_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + @parameterized.expand(COND_CASES_2D) + def test_conditioned_2d_models_shape(self, input_param): + net = SPADEDiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, input_param["label_nc"], 16, 16)), + torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + +class TestDiffusionModelUNet3D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_3D) + def test_shape_unconditioned_models(self, input_param): + net = SPADEDiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, input_param["label_nc"], 16, 16, 16)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = SPADEDiffusionModelUNet( + spatial_dims=3, + label_nc=3, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=4, + ) + with eval_mode(net): + result = net.forward( + torch.rand((1, in_channels, 16, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, 3, 16, 16, 16)), + ) + self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) + + def test_shape_conditioned_models(self): + net = SPADEDiffusionModelUNet( + spatial_dims=3, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(16, 16, 16), + attention_levels=(False, False, True), + norm_num_groups=16, + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 16, 16)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 16, 16)), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + +if __name__ == "__main__": + unittest.main() From 3447b09435e72e856a4b29b436c2fbe61159a42f Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 3 Jan 2024 16:30:24 +0000 Subject: [PATCH 08/32] 6676 port diffusion schedulers (#7332) Towards #6676 . ### Description This adds some base classes for DDPM noise schedulers + three scheduler types. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham --- docs/source/networks.rst | 20 ++ monai/networks/schedulers/__init__.py | 17 ++ monai/networks/schedulers/ddim.py | 284 ++++++++++++++++++++++ monai/networks/schedulers/ddpm.py | 243 +++++++++++++++++++ monai/networks/schedulers/pndm.py | 316 +++++++++++++++++++++++++ monai/networks/schedulers/scheduler.py | 203 ++++++++++++++++ monai/utils/misc.py | 4 +- tests/test_scheduler_ddim.py | 83 +++++++ tests/test_scheduler_ddpm.py | 104 ++++++++ tests/test_scheduler_pndm.py | 108 +++++++++ 10 files changed, 1380 insertions(+), 2 deletions(-) create mode 100644 monai/networks/schedulers/__init__.py create mode 100644 monai/networks/schedulers/ddim.py create mode 100644 monai/networks/schedulers/ddpm.py create mode 100644 monai/networks/schedulers/pndm.py create mode 100644 monai/networks/schedulers/scheduler.py create mode 100644 tests/test_scheduler_ddim.py create mode 100644 tests/test_scheduler_ddpm.py create mode 100644 tests/test_scheduler_pndm.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 79d5ef822e..f9375f1e97 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -778,6 +778,26 @@ Nets .. autoclass:: MultiScalePatchDiscriminator :members: +Diffusion Schedulers +-------------------- +.. autoclass:: monai.networks.schedulers.Scheduler + :members: + +`DDPM Scheduler` +~~~~~~~~~~~~~~~~ +.. autoclass:: monai.networks.schedulers.DDPMScheduler + :members: + +`DDIM Scheduler` +~~~~~~~~~~~~~~~~ +.. autoclass:: monai.networks.schedulers.DDIMScheduler + :members: + +`PNDM Scheduler` +~~~~~~~~~~~~~~~~ +.. autoclass:: monai.networks.schedulers.PNDMScheduler + :members: + Utilities --------- .. automodule:: monai.networks.utils diff --git a/monai/networks/schedulers/__init__.py b/monai/networks/schedulers/__init__.py new file mode 100644 index 0000000000..29e9020d65 --- /dev/null +++ b/monai/networks/schedulers/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from .ddim import DDIMScheduler +from .ddpm import DDPMScheduler +from .pndm import PNDMScheduler +from .scheduler import NoiseSchedules, Scheduler diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py new file mode 100644 index 0000000000..ec47ff8dc6 --- /dev/null +++ b/monai/networks/schedulers/ddim.py @@ -0,0 +1,284 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import numpy as np +import torch + +from monai.utils import StrEnum + +from .scheduler import Scheduler + + +class DDIMPredictionType(StrEnum): + """ + Set of valid prediction type names for the DDIM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + sample: directly predicting the noisy sample + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + SAMPLE = "sample" + V_PREDICTION = "v_prediction" + + +class DDIMScheduler(Scheduler): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising Diffusion + Implicit Models" https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store + clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. + set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one. + For the final step there is no previous alpha. When this option is `True` the previous alpha product is + fixed to `1`, otherwise it uses the value of alpha at step 0. + steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + prediction_type: member of DDPMPredictionType + schedule_args: arguments to pass to the schedule function + + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + schedule: str = "linear_beta", + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = DDIMPredictionType.EPSILON, + **schedule_args, + ) -> None: + super().__init__(num_train_timesteps, schedule, **schedule_args) + + if prediction_type not in DDIMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType") + + self.prediction_type = prediction_type + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64)) + + self.clip_sample = clip_sample + self.steps_offset = steps_offset + + # default the number of inference timesteps to the number of train steps + self.num_inference_steps: int + self.set_timesteps(self.num_train_timesteps) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps += self.steps_offset + + def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance: torch.Tensor = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + eta: float = 0.0, + generator: torch.Generator | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + eta: weight of noise for added noise in diffusion step. + predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon. + generator: random number generator. + + Returns: + pred_prev_sample: Predicted previous sample + pred_original_sample: Predicted original sample + """ + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - model_output -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.prediction_type == DDIMPredictionType.EPSILON: + pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5) + pred_epsilon = model_output + elif self.prediction_type == DDIMPredictionType.SAMPLE: + pred_original_sample = model_output + pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5) + elif self.prediction_type == DDIMPredictionType.V_PREDICTION: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # 4. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance**0.5 + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon + + # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + + if eta > 0: + # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 + device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else "cpu") + noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) + variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise + + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample, pred_original_sample + + def reversed_step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict the sample at the next timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + + Returns: + pred_prev_sample: Predicted previous sample + pred_original_sample: Predicted original sample + """ + # See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf + + # Notation ( -> + # - model_output -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_post_sample -> "x_t+1" + + # 1. get previous step value (=t+1) + prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas at timestep t+1 + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + + if self.prediction_type == DDIMPredictionType.EPSILON: + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.prediction_type == DDIMPredictionType.SAMPLE: + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.prediction_type == DDIMPredictionType.V_PREDICTION: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # 4. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon + + # 6. compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + return pred_post_sample, pred_original_sample diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py new file mode 100644 index 0000000000..a5173a1b65 --- /dev/null +++ b/monai/networks/schedulers/ddpm.py @@ -0,0 +1,243 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import numpy as np +import torch + +from monai.utils import StrEnum + +from .scheduler import Scheduler + + +class DDPMVarianceType(StrEnum): + """ + Valid names for DDPM Scheduler's `variance_type` argument. Options to clip the variance used when adding noise + to the denoised sample. + """ + + FIXED_SMALL = "fixed_small" + FIXED_LARGE = "fixed_large" + LEARNED = "learned" + LEARNED_RANGE = "learned_range" + + +class DDPMPredictionType(StrEnum): + """ + Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + sample: directly predicting the noisy sample + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + SAMPLE = "sample" + V_PREDICTION = "v_prediction" + + +class DDPMScheduler(Scheduler): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. Based on: Ho et al., "Denoising Diffusion Probabilistic Models" + https://arxiv.org/abs/2006.11239 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store + variance_type: member of DDPMVarianceType + clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. + prediction_type: member of DDPMPredictionType + schedule_args: arguments to pass to the schedule function + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + schedule: str = "linear_beta", + variance_type: str = DDPMVarianceType.FIXED_SMALL, + clip_sample: bool = True, + prediction_type: str = DDPMPredictionType.EPSILON, + **schedule_args, + ) -> None: + super().__init__(num_train_timesteps, schedule, **schedule_args) + + if variance_type not in DDPMVarianceType.__members__.values(): + raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`") + + if prediction_type not in DDPMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`") + + self.clip_sample = clip_sample + self.variance_type = variance_type + self.prediction_type = prediction_type + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + + def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor: + """ + Compute the mean of the posterior at timestep t. + + Args: + timestep: current timestep. + x0: the noise-free input. + x_t: the input noised to timestep t. + + Returns: + Returns the mean + """ + # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0), + # (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf) + alpha_t = self.alphas[timestep] + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + + x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t) + x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) + + mean: torch.Tensor = x_0_coefficient * x_0 + x_t_coefficient * x_t + + return mean + + def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = None) -> torch.Tensor: + """ + Compute the variance of the posterior at timestep t. + + Args: + timestep: current timestep. + predicted_variance: variance predicted by the model. + + Returns: + Returns the variance + """ + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance: torch.Tensor = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep] + # hacks - were probably added for training stability + if self.variance_type == DDPMVarianceType.FIXED_SMALL: + variance = torch.clamp(variance, min=1e-20) + elif self.variance_type == DDPMVarianceType.FIXED_LARGE: + variance = self.betas[timestep] + elif self.variance_type == DDPMVarianceType.LEARNED and predicted_variance is not None: + return predicted_variance + elif self.variance_type == DDPMVarianceType.LEARNED_RANGE and predicted_variance is not None: + min_log = variance + max_log = self.betas[timestep] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, generator: torch.Generator | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + generator: random number generator. + + Returns: + pred_prev_sample: Predicted previous sample + """ + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.prediction_type == DDPMPredictionType.EPSILON: + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.prediction_type == DDPMPredictionType.SAMPLE: + pred_original_sample = model_output + elif self.prediction_type == DDPMPredictionType.V_PREDICTION: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + + # 3. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t + current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if timestep > 0: + noise = torch.randn( + model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator + ).to(model_output.device) + variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise + + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample, pred_original_sample diff --git a/monai/networks/schedulers/pndm.py b/monai/networks/schedulers/pndm.py new file mode 100644 index 0000000000..c0728bbdff --- /dev/null +++ b/monai/networks/schedulers/pndm.py @@ -0,0 +1,316 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from typing import Any + +import numpy as np +import torch + +from monai.utils import StrEnum + +from .scheduler import Scheduler + + +class PNDMPredictionType(StrEnum): + """ + Set of valid prediction type names for the PNDM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + V_PREDICTION = "v_prediction" + + +class PNDMScheduler(Scheduler): + """ + Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, + namely Runge-Kutta method and a linear multi-step method. Based on: Liu et al., + "Pseudo Numerical Methods for Diffusion Models on Manifolds" https://arxiv.org/abs/2202.09778 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store + skip_prk_steps: + allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required + before plms step. + set_alpha_to_one: + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + prediction_type: member of DDPMPredictionType + steps_offset: + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + schedule_args: arguments to pass to the schedule function + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + schedule: str = "linear_beta", + skip_prk_steps: bool = False, + set_alpha_to_one: bool = False, + prediction_type: str = PNDMPredictionType.EPSILON, + steps_offset: int = 0, + **schedule_args, + ) -> None: + super().__init__(num_train_timesteps, schedule, **schedule_args) + + if prediction_type not in PNDMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of PNDMPredictionType") + + self.prediction_type = prediction_type + + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. + self.pndm_order = 4 + + self.skip_prk_steps = skip_prk_steps + self.steps_offset = steps_offset + + # running values + self.cur_model_output = torch.Tensor() + self.counter = 0 + self.cur_sample = torch.Tensor() + self.ets: list = [] + + # default the number of inference timesteps to the number of train steps + self.set_timesteps(num_train_timesteps) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64) + self._timesteps += self.steps_offset + + if self.skip_prk_steps: + # for some models like stable diffusion the prk steps can/should be skipped to + # produce better results. When using PNDM with `self.skip_prk_steps` the implementation + # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 + self.prk_timesteps = np.array([]) + self.plms_timesteps = self._timesteps[::-1] + + else: + prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( + np.array([0, self.num_train_timesteps // num_inference_steps // 2]), self.pndm_order + ) + self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() + self.plms_timesteps = self._timesteps[:-3][ + ::-1 + ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy + + timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + # update num_inference_steps - necessary if we use prk steps + self.num_inference_steps = len(self.timesteps) + + self.ets = [] + self.counter = 0 + + def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> tuple[torch.Tensor, Any]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + Returns: + pred_prev_sample: Predicted previous sample + """ + # return a tuple for consistency with samplers that return (previous pred, original sample pred) + + if self.counter < len(self.prk_timesteps) and not self.skip_prk_steps: + return self.step_prk(model_output=model_output, timestep=timestep, sample=sample), None + else: + return self.step_plms(model_output=model_output, timestep=timestep, sample=sample), None + + def step_prk(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> torch.Tensor: + """ + Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the + solution to the differential equation. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + + Returns: + pred_prev_sample: Predicted previous sample + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + diff_to_prev = 0 if self.counter % 2 else self.num_train_timesteps // self.num_inference_steps // 2 + prev_timestep = timestep - diff_to_prev + timestep = self.prk_timesteps[self.counter // 4 * 4] + + if self.counter % 4 == 0: + self.cur_model_output = 1 / 6 * model_output + self.ets.append(model_output) + self.cur_sample = sample + elif (self.counter - 1) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 2) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 3) % 4 == 0: + model_output = self.cur_model_output + 1 / 6 * model_output + self.cur_model_output = torch.Tensor() + + # cur_sample should not be an empty torch.Tensor() + cur_sample = self.cur_sample if self.cur_sample.numel() != 0 else sample + + prev_sample: torch.Tensor = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) + self.counter += 1 + + return prev_sample + + def step_plms(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> Any: + """ + Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple + times to approximate the solution. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + + Returns: + pred_prev_sample: Predicted previous sample + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if not self.skip_prk_steps and len(self.ets) < 3: + raise ValueError( + f"{self.__class__} can only be run AFTER scheduler has been run " + "in 'prk' mode for at least 12 iterations " + ) + + prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps + + if self.counter != 1: + self.ets = self.ets[-3:] + self.ets.append(model_output) + else: + prev_timestep = timestep + timestep = timestep + self.num_train_timesteps // self.num_inference_steps + + if len(self.ets) == 1 and self.counter == 0: + model_output = model_output + self.cur_sample = sample + elif len(self.ets) == 1 and self.counter == 1: + model_output = (model_output + self.ets[-1]) / 2 + sample = self.cur_sample + self.cur_sample = torch.Tensor() + elif len(self.ets) == 2: + model_output = (3 * self.ets[-1] - self.ets[-2]) / 2 + elif len(self.ets) == 3: + model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 + else: + model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) + + prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) + self.counter += 1 + + return prev_sample + + def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor): + # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf + # this function computes x_(t−δ) using the formula of (9) + # Note that x_t needs to be added to both sides of the equation + + # Notation ( -> + # alpha_prod_t -> α_t + # alpha_prod_t_prev -> α_(t−δ) + # beta_prod_t -> (1 - α_t) + # beta_prod_t_prev -> (1 - α_(t−δ)) + # sample -> x_t + # model_output -> e_θ(x_t, t) + # prev_sample -> x_(t−δ) + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + if self.prediction_type == PNDMPredictionType.V_PREDICTION: + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # corresponds to (α_(t−δ) - α_t) divided by + # denominator of x_t in formula (9) and plus 1 + # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = + # sqrt(α_(t−δ)) / sqrt(α_t)) + sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) + + # corresponds to denominator of e_θ(x_t, t) in formula (9) + model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( + alpha_prod_t * beta_prod_t * alpha_prod_t_prev + ) ** (0.5) + + # full formula (9) + prev_sample = ( + sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff + ) + + return prev_sample diff --git a/monai/networks/schedulers/scheduler.py b/monai/networks/schedulers/scheduler.py new file mode 100644 index 0000000000..17bb526abc --- /dev/null +++ b/monai/networks/schedulers/scheduler.py @@ -0,0 +1,203 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + + +from __future__ import annotations + +import torch +import torch.nn as nn + +from monai.utils import ComponentStore, unsqueeze_right + +NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules") + + +@NoiseSchedules.add_def("linear_beta", "Linear beta schedule") +def _linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): + """ + Linear beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + + +@NoiseSchedules.add_def("scaled_linear_beta", "Scaled linear beta schedule") +def _scaled_linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): + """ + Scaled linear beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ + return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + +@NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule") +def _sigmoid_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2, sig_range: float = 6): + """ + Sigmoid beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + sig_range: pos/neg range of sigmoid input, default 6 + + Returns: + betas: beta schedule tensor + """ + betas = torch.linspace(-sig_range, sig_range, num_train_timesteps) + return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + + +@NoiseSchedules.add_def("cosine", "Cosine schedule") +def _cosine_beta(num_train_timesteps: int, s: float = 8e-3): + """ + Cosine noise schedule, see https://arxiv.org/abs/2102.09672 + + Args: + num_train_timesteps: number of timesteps + s: smoothing factor, default 8e-3 (see referenced paper) + + Returns: + (betas, alphas, alpha_cumprod) values + """ + x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1) + alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod /= alphas_cumprod[0].item() + alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999) + betas = 1.0 - alphas + return betas, alphas, alphas_cumprod[:-1] + + +class Scheduler(nn.Module): + """ + Base class for other schedulers based on a noise schedule function. + + This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here + the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`, + which is the name of a component in NoiseSchedules. These components must all be callables which return either + the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions + can be provided by using the NoiseSchedules.add_def, for example: + + .. code-block:: python + + from monai.networks.schedulers import NoiseSchedules, DDPMScheduler + + @NoiseSchedules.add_def("my_beta_schedule", "Some description of your function") + def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + + scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="my_beta_schedule") + + All such functions should have an initial positional integer argument `num_train_timesteps` stating the number of + timesteps the schedule is for, otherwise any other arguments can be given which will be passed by keyword through + the constructor's `schedule_args` value. To see what noise functions are available, print the object NoiseSchedules + to get a listing of stored objects with their docstring descriptions. + + Note: in previous versions of the schedulers the argument `schedule_beta` was used to state the beta schedule + type, this now replaced with `schedule` and most names used with the previous argument now have "_beta" appended + to them, eg. 'schedule_beta="linear"' -> 'schedule="linear_beta"'. The `beta_start` and `beta_end` arguments are + still used for some schedules but these are provided as keyword arguments now. + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, + a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple + schedule_args: arguments to pass to the schedule function + """ + + def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", **schedule_args) -> None: + super().__init__() + schedule_args["num_train_timesteps"] = num_train_timesteps + noise_sched = NoiseSchedules[schedule](**schedule_args) + + # set betas, alphas, alphas_cumprod based off return value from noise function + if isinstance(noise_sched, tuple): + self.betas, self.alphas, self.alphas_cumprod = noise_sched + else: + self.betas = noise_sched + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + self.num_train_timesteps = num_train_timesteps + self.one = torch.tensor(1.0) + + # settable values + self.num_inference_steps: int | None = None + self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1) + + def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + """ + Add noise to the original samples. + + Args: + original_samples: original samples + noise: noise to add to samples + timesteps: timesteps tensor indicating the timestep to be computed for each sample. + + Returns: + noisy_samples: sample with added noise + """ + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_cumprod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, original_samples.ndim) + sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right( + (1 - self.alphas_cumprod[timesteps]) ** 0.5, original_samples.ndim + ) + + noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim) + sqrt_one_minus_alpha_prod = unsqueeze_right((1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity diff --git a/monai/utils/misc.py b/monai/utils/misc.py index d6ff370f69..4f2501a7ee 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -890,11 +890,11 @@ def is_sqrt(num: Sequence[int] | int) -> bool: return ensure_tuple(ret) == num -def unsqueeze_right(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: +def unsqueeze_right(arr: torch.Tensor, ndim: int) -> torch.Tensor: """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(...,) + (None,) * (ndim - arr.ndim)] -def unsqueeze_left(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: +def unsqueeze_left(arr: torch.Tensor, ndim: int) -> torch.Tensor: """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(None,) * (ndim - arr.ndim)] diff --git a/tests/test_scheduler_ddim.py b/tests/test_scheduler_ddim.py new file mode 100644 index 0000000000..1a8f8cab67 --- /dev/null +++ b/tests/test_scheduler_ddim.py @@ -0,0 +1,83 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.schedulers import DDIMScheduler +from tests.utils import assert_allclose + +TEST_2D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) + +TEST_3D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) + +TEST_CASES = TEST_2D_CASE + TEST_3D_CASE + +TEST_FULl_LOOP = [ + [{"schedule": "linear_beta"}, (1, 1, 2, 2), torch.Tensor([[[[-0.9579, -0.6457], [0.4684, -0.9694]]]])] +] + + +class TestDDPMScheduler(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_add_noise(self, input_param, input_shape, expected_shape): + scheduler = DDIMScheduler(**input_param) + scheduler.set_timesteps(num_inference_steps=100) + original_sample = torch.zeros(input_shape) + noise = torch.randn_like(original_sample) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() + + noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) + self.assertEqual(noisy.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + def test_step_shape(self, input_param, input_shape, expected_shape): + scheduler = DDIMScheduler(**input_param) + scheduler.set_timesteps(num_inference_steps=100) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, expected_shape) + self.assertEqual(output_step[1].shape, expected_shape) + + @parameterized.expand(TEST_FULl_LOOP) + def test_full_timestep_loop(self, input_param, input_shape, expected_output): + scheduler = DDIMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) + + def test_set_timesteps(self): + scheduler = DDIMScheduler(num_train_timesteps=1000) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 100) + self.assertEqual(len(scheduler.timesteps), 100) + + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = DDIMScheduler(num_train_timesteps=1000) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_scheduler_ddpm.py b/tests/test_scheduler_ddpm.py new file mode 100644 index 0000000000..f0447aded2 --- /dev/null +++ b/tests/test_scheduler_ddpm.py @@ -0,0 +1,104 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.schedulers import DDPMScheduler +from tests.utils import assert_allclose + +TEST_2D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + for variance_type in ["fixed_small", "fixed_large"]: + TEST_2D_CASE.append( + [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16), (2, 6, 16, 16)] + ) + +TEST_3D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + for variance_type in ["fixed_small", "fixed_large"]: + TEST_3D_CASE.append( + [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)] + ) + +TEST_CASES = TEST_2D_CASE + TEST_3D_CASE + +TEST_FULl_LOOP = [ + [{"schedule": "linear_beta"}, (1, 1, 2, 2), torch.Tensor([[[[-1.0153, -0.3218], [0.8454, -0.7870]]]])] +] + + +class TestDDPMScheduler(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_add_noise(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + original_sample = torch.zeros(input_shape) + noise = torch.randn_like(original_sample) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() + + noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) + self.assertEqual(noisy.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + def test_step_shape(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, expected_shape) + self.assertEqual(output_step[1].shape, expected_shape) + + @parameterized.expand(TEST_FULl_LOOP) + def test_full_timestep_loop(self, input_param, input_shape, expected_output): + scheduler = DDPMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) + + @parameterized.expand(TEST_CASES) + def test_get_velocity_shape(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + sample = torch.randn(input_shape) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],)).long() + velocity = scheduler.get_velocity(sample=sample, noise=sample, timesteps=timesteps) + self.assertEqual(velocity.shape, expected_shape) + + def test_step_learned(self): + for variance_type in ["learned", "learned_range"]: + scheduler = DDPMScheduler(variance_type=variance_type) + model_output = torch.randn(2, 6, 16, 16) + sample = torch.randn(2, 3, 16, 16) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, sample.shape) + self.assertEqual(output_step[1].shape, sample.shape) + + def test_set_timesteps(self): + scheduler = DDPMScheduler(num_train_timesteps=1000) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 100) + self.assertEqual(len(scheduler.timesteps), 100) + + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = DDPMScheduler(num_train_timesteps=1000) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_scheduler_pndm.py b/tests/test_scheduler_pndm.py new file mode 100644 index 0000000000..69e5e403f5 --- /dev/null +++ b/tests/test_scheduler_pndm.py @@ -0,0 +1,108 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.schedulers import PNDMScheduler +from tests.utils import assert_allclose + +TEST_2D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) + +TEST_3D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) + +TEST_CASES = TEST_2D_CASE + TEST_3D_CASE + +TEST_FULl_LOOP = [ + [ + {"schedule": "linear_beta"}, + (1, 1, 2, 2), + torch.Tensor([[[[-2123055.2500, -459014.2812], [2863438.0000, -1263401.7500]]]]), + ] +] + + +class TestDDPMScheduler(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_add_noise(self, input_param, input_shape, expected_shape): + scheduler = PNDMScheduler(**input_param) + original_sample = torch.zeros(input_shape) + noise = torch.randn_like(original_sample) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() + noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) + self.assertEqual(noisy.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + def test_step_shape(self, input_param, input_shape, expected_shape): + scheduler = PNDMScheduler(**input_param) + scheduler.set_timesteps(600) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, expected_shape) + self.assertEqual(output_step[1], None) + + @parameterized.expand(TEST_FULl_LOOP) + def test_full_timestep_loop(self, input_param, input_shape, expected_output): + scheduler = PNDMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) + + @parameterized.expand(TEST_FULl_LOOP) + def test_timestep_two_loops(self, input_param, input_shape, expected_output): + scheduler = PNDMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + torch.manual_seed(42) + model_output2 = torch.randn(input_shape) + sample2 = torch.randn(input_shape) + scheduler.set_timesteps(50) + for t in range(50): + sample2, _ = scheduler.step(model_output=model_output2, timestep=t, sample=sample2) + assert_allclose(sample, sample2, rtol=1e-3, atol=1e-3) + + def test_set_timesteps(self): + scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=True) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 100) + self.assertEqual(len(scheduler.timesteps), 100) + + def test_set_timesteps_prk(self): + scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=False) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 109) + self.assertEqual(len(scheduler.timesteps), 109) + + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = PNDMScheduler(num_train_timesteps=1000) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000) + + +if __name__ == "__main__": + unittest.main() From 3ab5c62b9e0964d129980f7970aabefa9d0e2d2f Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 5 Jan 2024 06:52:15 +0000 Subject: [PATCH 09/32] 6676 port diffusion schedulers (#7364) This is an update to PR https://github.com/Project-MONAI/MONAI/pull/7332 - I addressed the comments but failed to push the changes before it was merged! Changes are very minor. ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- monai/networks/schedulers/ddim.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py index ec47ff8dc6..78e3cc2a0c 100644 --- a/monai/networks/schedulers/ddim.py +++ b/monai/networks/schedulers/ddim.py @@ -34,23 +34,10 @@ import numpy as np import torch -from monai.utils import StrEnum - +from .ddpm import DDPMPredictionType from .scheduler import Scheduler - -class DDIMPredictionType(StrEnum): - """ - Set of valid prediction type names for the DDIM scheduler's `prediction_type` argument. - - epsilon: predicting the noise of the diffusion process - sample: directly predicting the noisy sample - v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf - """ - - EPSILON = "epsilon" - SAMPLE = "sample" - V_PREDICTION = "v_prediction" +DDIMPredictionType = DDPMPredictionType class DDIMScheduler(Scheduler): @@ -126,6 +113,13 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N self.num_inference_steps = num_inference_steps step_ratio = self.num_train_timesteps // self.num_inference_steps + if self.steps_offset >= step_ratio: + raise ValueError( + f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to " + f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed" + f" the max train timestep." + ) + # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) @@ -159,7 +153,6 @@ def step( timestep: current discrete timestep in the diffusion chain. sample: current instance of sample being created by diffusion process. eta: weight of noise for added noise in diffusion step. - predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon. generator: random number generator. Returns: From 0a549fe937ab63b86ddc741dec8cddcd1085db4d Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 8 Jan 2024 14:56:58 +0000 Subject: [PATCH 10/32] Adds ordering util (#7369) Towards #6676 . ### Description This ordering util got missed out my previous PR for the Generative utils. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham --- docs/source/utils.rst | 5 + monai/utils/ordering.py | 207 ++++++++++++++++++++++++++ tests/test_ordering.py | 318 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 530 insertions(+) create mode 100644 monai/utils/ordering.py create mode 100644 tests/test_ordering.py diff --git a/docs/source/utils.rst b/docs/source/utils.rst index 527247799f..fef671e1f8 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -81,3 +81,8 @@ Component store --------------- .. autoclass:: monai.utils.component_store.ComponentStore :members: + +Ordering +-------- +.. automodule:: monai.utils.ordering + :members: diff --git a/monai/utils/ordering.py b/monai/utils/ordering.py new file mode 100644 index 0000000000..1be61f98ab --- /dev/null +++ b/monai/utils/ordering.py @@ -0,0 +1,207 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import numpy as np + +from monai.utils.enums import OrderingTransformations, OrderingType + + +class Ordering: + """ + Ordering class that projects a 2D or 3D image into a 1D sequence. It also allows the image to be transformed with + one of the following transformations: + Reflection (see np.flip for more details). + Transposition (see np.transpose for more details). + 90-degree rotation (see np.rot90 for more details). + + The transformations are applied in the order specified by the transformation_order parameter. + + Args: + ordering_type: The ordering type. One of the following: + - 'raster_scan': The image is projected into a 1D sequence by scanning the image from left to right and from + top to bottom. Also called a row major ordering. + - 's_curve': The image is projected into a 1D sequence by scanning the image in a circular snake like + pattern from top left towards right gowing in a spiral towards the center. + - random': The image is projected into a 1D sequence by randomly shuffling the image. + spatial_dims: The number of spatial dimensions of the image. + dimensions: The dimensions of the image. + reflected_spatial_dims: A tuple of booleans indicating whether to reflect the image along each spatial dimension. + transpositions_axes: A tuple of tuples indicating the axes to transpose the image along. + rot90_axes: A tuple of tuples indicating the axes to rotate the image along. + transformation_order: The order in which to apply the transformations. + """ + + def __init__( + self, + ordering_type: str, + spatial_dims: int, + dimensions: tuple[int, int, int] | tuple[int, int, int, int], + reflected_spatial_dims: tuple[bool, bool] | None = None, + transpositions_axes: tuple[tuple[int, int], ...] | tuple[tuple[int, int, int], ...] | None = None, + rot90_axes: tuple[tuple[int, int], ...] | None = None, + transformation_order: tuple[str, ...] = ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + ) -> None: + super().__init__() + self.ordering_type = ordering_type + + if self.ordering_type not in list(OrderingType): + raise ValueError( + f"ordering_type must be one of the following {list(OrderingType)}, but got {self.ordering_type}." + ) + + self.spatial_dims = spatial_dims + self.dimensions = dimensions + + if len(dimensions) != self.spatial_dims + 1: + raise ValueError(f"dimensions must be of length {self.spatial_dims + 1}, but got {len(dimensions)}.") + + self.reflected_spatial_dims = reflected_spatial_dims + self.transpositions_axes = transpositions_axes + self.rot90_axes = rot90_axes + if len(set(transformation_order)) != len(transformation_order): + raise ValueError(f"No duplicates are allowed. Received {transformation_order}.") + + for transformation in transformation_order: + if transformation not in list(OrderingTransformations): + raise ValueError( + f"Valid transformations are {list(OrderingTransformations)} but received {transformation}." + ) + self.transformation_order = transformation_order + + self.template = self._create_template() + self._sequence_ordering = self._create_ordering() + self._revert_sequence_ordering = np.argsort(self._sequence_ordering) + + def __call__(self, x: np.ndarray) -> np.ndarray: + x = x[self._sequence_ordering] + + return x + + def get_sequence_ordering(self) -> np.ndarray: + return self._sequence_ordering + + def get_revert_sequence_ordering(self) -> np.ndarray: + return self._revert_sequence_ordering + + def _create_ordering(self) -> np.ndarray: + self.template = self._transform_template() + order = self._order_template(template=self.template) + + return order + + def _create_template(self) -> np.ndarray: + spatial_dimensions = self.dimensions[1:] + template = np.arange(np.prod(spatial_dimensions)).reshape(*spatial_dimensions) + + return template + + def _transform_template(self) -> np.ndarray: + for transformation in self.transformation_order: + if transformation == OrderingTransformations.TRANSPOSE.value: + self.template = self._transpose_template(template=self.template) + elif transformation == OrderingTransformations.ROTATE_90.value: + self.template = self._rot90_template(template=self.template) + elif transformation == OrderingTransformations.REFLECT.value: + self.template = self._flip_template(template=self.template) + + return self.template + + def _transpose_template(self, template: np.ndarray) -> np.ndarray: + if self.transpositions_axes is not None: + for axes in self.transpositions_axes: + template = np.transpose(template, axes=axes) + + return template + + def _flip_template(self, template: np.ndarray) -> np.ndarray: + if self.reflected_spatial_dims is not None: + for axis, to_reflect in enumerate(self.reflected_spatial_dims): + template = np.flip(template, axis=axis) if to_reflect else template + + return template + + def _rot90_template(self, template: np.ndarray) -> np.ndarray: + if self.rot90_axes is not None: + for axes in self.rot90_axes: + template = np.rot90(template, axes=axes) + + return template + + def _order_template(self, template: np.ndarray) -> np.ndarray: + depths = None + if self.spatial_dims == 2: + rows, columns = template.shape[0], template.shape[1] + else: + rows, columns, depths = (template.shape[0], template.shape[1], template.shape[2]) + + sequence = eval(f"self.{self.ordering_type}_idx")(rows, columns, depths) + + ordering = np.array([template[tuple(e)] for e in sequence]) + + return ordering + + @staticmethod + def raster_scan_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray: + idx: list[tuple] = [] + + for r in range(rows): + for c in range(cols): + if depths is not None: + for d in range(depths): + idx.append((r, c, d)) + else: + idx.append((r, c)) + + idx_np = np.array(idx) + + return idx_np + + @staticmethod + def s_curve_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray: + idx: list[tuple] = [] + + for r in range(rows): + col_idx = range(cols) if r % 2 == 0 else range(cols - 1, -1, -1) + for c in col_idx: + if depths: + depth_idx = range(depths) if c % 2 == 0 else range(depths - 1, -1, -1) + + for d in depth_idx: + idx.append((r, c, d)) + else: + idx.append((r, c)) + + idx_np = np.array(idx) + + return idx_np + + @staticmethod + def random_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray: + idx: list[tuple] = [] + + for r in range(rows): + for c in range(cols): + if depths: + for d in range(depths): + idx.append((r, c, d)) + else: + idx.append((r, c)) + + idx_np = np.array(idx) + np.random.shuffle(idx_np) + + return idx_np diff --git a/tests/test_ordering.py b/tests/test_ordering.py new file mode 100644 index 0000000000..0c52dba5e5 --- /dev/null +++ b/tests/test_ordering.py @@ -0,0 +1,318 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.utils.enums import OrderingTransformations, OrderingType +from monai.utils.ordering import Ordering + +TEST_2D_NON_RANDOM = [ + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 2, 3], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 3, 2], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [2, 3, 0, 1], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [2, 3, 1, 0], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": ((1, 0),), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 2, 1, 3], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": ((1, 0),), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 2, 3, 1], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [1, 3, 0, 2], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [1, 3, 2, 0], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 2, 3], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 3, 2], + ], +] + +TEST_2D_RANDOM = [ + [ + { + "ordering_type": OrderingType.RANDOM, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [[0, 1, 2, 3], [0, 1, 3, 2]], + ] +] + +TEST_3D = [ + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 3, + "dimensions": (1, 2, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 2, 3, 4, 5, 6, 7], + ] +] + +TEST_ORDERING_TYPE_FAILURE = [ + [ + { + "ordering_type": "hilbert", + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + } + ] +] + +TEST_ORDERING_TRANSFORMATION_FAILURE = [ + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + "flip", + ), + } + ] +] + +TEST_REVERT = [ + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + } + ] +] + + +class TestOrdering(unittest.TestCase): + @parameterized.expand(TEST_2D_NON_RANDOM + TEST_3D) + def test_ordering(self, input_param, expected_sequence_ordering): + ordering = Ordering(**input_param) + self.assertTrue(np.array_equal(ordering.get_sequence_ordering(), expected_sequence_ordering, equal_nan=True)) + + @parameterized.expand(TEST_ORDERING_TYPE_FAILURE) + def test_ordering_type_failure(self, input_param): + with self.assertRaises(ValueError): + Ordering(**input_param) + + @parameterized.expand(TEST_ORDERING_TRANSFORMATION_FAILURE) + def test_ordering_transformation_failure(self, input_param): + with self.assertRaises(ValueError): + Ordering(**input_param) + + @parameterized.expand(TEST_2D_RANDOM) + def test_random(self, input_param, not_in_expected_sequence_ordering): + ordering = Ordering(**input_param) + + not_in = [ + np.array_equal(sequence, ordering.get_sequence_ordering(), equal_nan=True) + for sequence in not_in_expected_sequence_ordering + ] + + self.assertFalse(np.any(not_in)) + + @parameterized.expand(TEST_REVERT) + def test_revert(self, input_param): + sequence = np.random.randint(0, 100, size=input_param["dimensions"]).flatten() + + ordering = Ordering(**input_param) + + reverted_sequence = sequence[ordering.get_sequence_ordering()] + reverted_sequence = reverted_sequence[ordering.get_revert_sequence_ordering()] + + self.assertTrue(np.array_equal(sequence, reverted_sequence, equal_nan=True)) + + +if __name__ == "__main__": + unittest.main() From 510f7bc1eb4505d61f9aec6a9a96c444051a3e45 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 18 Jan 2024 12:38:46 +0000 Subject: [PATCH 11/32] 6676 port generative inferers (#7379) Part of #6676 . ### Description Adds Inferers to assist with training and sampling from diffusion models and controllers. Also takes the opportunity to make two changes which slipped through the previous PRs: - rename the `num_channels` arg in the spade diffusion unet to `channels` to be consistent with all the other models added from Generative - this slipped through in the networks PR. - add the `Ordering` class to `__init__.py` for easier import ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/inferers.rst | 23 + monai/inferers/__init__.py | 5 + monai/inferers/inferer.py | 1280 ++++++++++++++++- monai/networks/nets/diffusion_model_unet.py | 6 +- .../nets/spade_diffusion_model_unet.py | 40 +- monai/utils/__init__.py | 1 + setup.cfg | 14 +- tests/test_controlnet_inferers.py | 1270 ++++++++++++++++ tests/test_diffusion_inferer.py | 226 +++ tests/test_flexible_unet.py | 2 +- tests/test_invertd.py | 12 +- tests/test_latent_diffusion_inferer.py | 796 ++++++++++ tests/test_ordering.py | 29 - .../test_spade_autoencoderkl.py | 0 .../test_spade_diffusion_model_unet.py | 66 +- tests/test_vqvaetransformer_inferer.py | 284 ++++ 16 files changed, 3955 insertions(+), 99 deletions(-) create mode 100644 tests/test_controlnet_inferers.py create mode 100644 tests/test_diffusion_inferer.py create mode 100644 tests/test_latent_diffusion_inferer.py rename test_spade_autoencoderkl.py => tests/test_spade_autoencoderkl.py (100%) rename test_spade_diffusion_model_unet.py => tests/test_spade_diffusion_model_unet.py (92%) create mode 100644 tests/test_vqvaetransformer_inferer.py diff --git a/docs/source/inferers.rst b/docs/source/inferers.rst index 33f9e14d83..326f56e96c 100644 --- a/docs/source/inferers.rst +++ b/docs/source/inferers.rst @@ -49,6 +49,29 @@ Inferers :members: :special-members: __call__ +`DiffusionInferer` +~~~~~~~~~~~~~~~~~~ +.. autoclass:: DiffusionInferer + :members: + :special-members: __call__ + +`LatentDiffusionInferer` +~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LatentDiffusionInferer + :members: + :special-members: __call__ + +`ControlNetDiffusionInferer` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ControlNetDiffusionInferer + :members: + :special-members: __call__ + +`ControlNetLatentDiffusionInferer` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ControlNetLatentDiffusionInferer + :members: + :special-members: __call__ Splitters --------- diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py index 960380bfb8..fc78b9f7c4 100644 --- a/monai/inferers/__init__.py +++ b/monai/inferers/__init__.py @@ -12,13 +12,18 @@ from __future__ import annotations from .inferer import ( + ControlNetDiffusionInferer, + ControlNetLatentDiffusionInferer, + DiffusionInferer, Inferer, + LatentDiffusionInferer, PatchInferer, SaliencyInferer, SimpleInferer, SliceInferer, SlidingWindowInferer, SlidingWindowInfererAdapt, + VQVAETransformerInferer, ) from .merger import AvgMerger, Merger, ZarrAvgMerger from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 0b4199938d..72bcb8fd5a 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -11,24 +11,41 @@ from __future__ import annotations +import math import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from functools import partial from pydoc import locate from typing import Any import torch import torch.nn as nn +import torch.nn.functional as F from monai.apps.utils import get_logger +from monai.data import decollate_batch from monai.data.meta_tensor import MetaTensor from monai.data.thread_buffer import ThreadBuffer from monai.inferers.merger import AvgMerger, Merger from monai.inferers.splitter import Splitter from monai.inferers.utils import compute_importance_map, sliding_window_inference -from monai.utils import BlendMode, PatchKeys, PytorchPadMode, ensure_tuple, optional_import +from monai.networks.nets import ( + VQVAE, + AutoencoderKL, + ControlNet, + DecoderOnlyTransformer, + DiffusionModelUNet, + SPADEAutoencoderKL, + SPADEDiffusionModelUNet, +) +from monai.networks.schedulers import Scheduler +from monai.transforms import CenterSpatialCrop, SpatialPad +from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import from monai.visualize import CAM, GradCAM, GradCAMpp +tqdm, has_tqdm = optional_import("tqdm", name="tqdm") + logger = get_logger(__name__) __all__ = [ @@ -752,3 +769,1264 @@ def network_wrapper( return out return tuple(out_i.unsqueeze(dim=self.spatial_dim + 2) for out_i in out) + + +class DiffusionInferer(Inferer): + """ + DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass + for a training iteration, and sample from the model. + + Args: + scheduler: diffusion scheduler. + """ + + def __init__(self, scheduler: Scheduler) -> None: # type: ignore[override] + super().__init__() + + self.scheduler = scheduler + + def __call__( # type: ignore[override] + self, + inputs: torch.Tensor, + diffusion_model: DiffusionModelUNet, + noise: torch.Tensor, + timesteps: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: Input image to which noise is added. + diffusion_model: diffusion model. + noise: random noise, of the same shape as the input. + timesteps: random timesteps. + condition: Conditioning for network input. + mode: Conditioning mode for the network. + seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be + provided on the forward (for SPADE-like AE or SPADE-like DM) + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + noisy_image: torch.Tensor = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + if mode == "concat": + if condition is None: + raise ValueError("Conditioning is required for concat condition") + else: + noisy_image = torch.cat([noisy_image, condition], dim=1) + condition = None + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + prediction: torch.Tensor = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition) + + return prediction + + @torch.no_grad() + def sample( + self, + input_noise: torch.Tensor, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired sample. + diffusion_model: model to sample from. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if mode == "concat" and conditioning is None: + raise ValueError("Conditioning must be supplied for if condition mode is concat.") + if not scheduler: + scheduler = self.scheduler + image = input_noise + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + for t in progress_bar: + # 1. predict noise model_output + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + if mode == "concat" and conditioning is not None: + model_input = torch.cat([image, conditioning], dim=1) + model_output = diffusion_model( + model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None + ) + else: + model_output = diffusion_model( + image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning + ) + + # 2. compute previous image: x_t -> x_t-1 + image, _ = scheduler.step(model_output, t, image) + if save_intermediates and t % intermediate_steps == 0: + intermediates.append(image) + if save_intermediates: + return image, intermediates + else: + return image + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple = (0, 255), + scaled_input_range: tuple = (0, 1), + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods for an input. + + Args: + inputs: input images, NxCxHxW[xD] + diffusion_model: model to compute likelihood from + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + + if not scheduler: + scheduler = self.scheduler + if scheduler._get_name() != "DDPMScheduler": + raise NotImplementedError( + f"Likelihood computation is only compatible with DDPMScheduler," + f" you are using {scheduler._get_name()}" + ) + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if mode == "concat" and conditioning is None: + raise ValueError("Conditioning must be supplied for if condition mode is concat.") + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + noise = torch.randn_like(inputs).to(inputs.device) + total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) + for t in progress_bar: + timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + if mode == "concat" and conditioning is not None: + noisy_image = torch.cat([noisy_image, conditioning], dim=1) + model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None) + else: + model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) + + # get the model's predicted mean, and variance if it is predicted + if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[t] + alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if scheduler.prediction_type == "epsilon": + pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif scheduler.prediction_type == "sample": + pred_original_sample = model_output + elif scheduler.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output + # 3. Clip "predicted x_0" + if scheduler.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t + current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image + + # get the posterior mean and variance + posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) + + log_posterior_variance = torch.log(posterior_variance) + log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + + if t == 0: + # compute -log p(x_0|x_1) + kl = -self._get_decoder_log_likelihood( + inputs=inputs, + means=predicted_mean, + log_scales=0.5 * log_predicted_variance, + original_input_range=original_input_range, + scaled_input_range=scaled_input_range, + ) + else: + # compute kl between two normals + kl = 0.5 * ( + -1.0 + + log_predicted_variance + - log_posterior_variance + + torch.exp(log_posterior_variance - log_predicted_variance) + + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) + ) + total_kl += kl.view(kl.shape[0], -1).mean(dim=1) + if save_intermediates: + intermediates.append(kl.cpu()) + + if save_intermediates: + return total_kl, intermediates + else: + return total_kl + + def _approx_standard_normal_cdf(self, x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. Code adapted from https://github.com/openai/improved-diffusion. + """ + + return 0.5 * ( + 1.0 + torch.tanh(torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3))) + ) + + def _get_decoder_log_likelihood( + self, + inputs: torch.Tensor, + means: torch.Tensor, + log_scales: torch.Tensor, + original_input_range: tuple = (0, 255), + scaled_input_range: tuple = (0, 1), + ) -> torch.Tensor: + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. Code adapted from https://github.com/openai/improved-diffusion. + + Args: + input: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + means: the Gaussian mean Tensor. + log_scales: the Gaussian log stddev Tensor. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + """ + if inputs.shape != means.shape: + raise ValueError(f"Inputs and means must have the same shape, got {inputs.shape} and {means.shape}") + bin_width = (scaled_input_range[1] - scaled_input_range[0]) / ( + original_input_range[1] - original_input_range[0] + ) + centered_x = inputs - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + bin_width / 2) + cdf_plus = self._approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - bin_width / 2) + cdf_min = self._approx_standard_normal_cdf(min_in) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + inputs < -0.999, + log_cdf_plus, + torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), + ) + return log_probs + + +class LatentDiffusionInferer(DiffusionInferer): + """ + LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can + be used to perform a signal forward pass for a training iteration, and sample from the model. + + Args: + scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents. + scale_factor: scale factor to multiply the values of the latent representation before processing it by the + second stage. + ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape. + autoencoder_latent_shape: autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a + difference between the autoencoder's latent shape and the DM shape. + """ + + def __init__( + self, + scheduler: Scheduler, + scale_factor: float = 1.0, + ldm_latent_shape: list | None = None, + autoencoder_latent_shape: list | None = None, + ) -> None: + super().__init__(scheduler=scheduler) + self.scale_factor = scale_factor + if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None): + raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None, and vice versa.") + self.ldm_latent_shape = ldm_latent_shape + self.autoencoder_latent_shape = autoencoder_latent_shape + if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None: + self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) + self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape) + + def __call__( # type: ignore[override] + self, + inputs: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + noise: torch.Tensor, + timesteps: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted and noise is added. + autoencoder_model: first stage model. + diffusion_model: diffusion model. + noise: random noise, of the same shape as the latent representation. + timesteps: random timesteps. + condition: conditioning for network input. + mode: Conditioning mode for the network. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + with torch.no_grad(): + latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) + + prediction: torch.Tensor = super().__call__( + inputs=latent, + diffusion_model=diffusion_model, + noise=noise, + timesteps=timesteps, + condition=condition, + mode=mode, + seg=seg, + ) + return prediction + + @torch.no_grad() + def sample( # type: ignore[override] + self, + input_noise: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired latent representation. + autoencoder_model: first stage model. + diffusion_model: model to sample from. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + + if ( + isinstance(autoencoder_model, SPADEAutoencoderKL) + and isinstance(diffusion_model, SPADEDiffusionModelUNet) + and autoencoder_model.decoder.label_nc != diffusion_model.label_nc + ): + raise ValueError( + f"If both autoencoder_model and diffusion_model implement SPADE, the number of semantic" + f"labels for each must be compatible, but got {autoencoder_model.decoder.label_nc} and" + f"{diffusion_model.label_nc}" + ) + + outputs = super().sample( + input_noise=input_noise, + diffusion_model=diffusion_model, + scheduler=scheduler, + save_intermediates=save_intermediates, + intermediate_steps=intermediate_steps, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + + if save_intermediates: + latent, latent_intermediates = outputs + else: + latent = outputs + + if self.autoencoder_latent_shape is not None: + latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) + latent_intermediates = [ + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + ] + + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + image = decode(latent / self.scale_factor) + + if save_intermediates: + intermediates = [] + for latent_intermediate in latent_intermediates: + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + intermediates.append(decode(latent_intermediate / self.scale_factor)) + return image, intermediates + + else: + return image + + @torch.no_grad() + def get_likelihood( # type: ignore[override] + self, + inputs: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + autoencoder_model: first stage model. + diffusion_model: model to compute likelihood from + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) + + outputs = super().get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + scheduler=scheduler, + save_intermediates=save_intermediates, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + + if save_intermediates and resample_latent_likelihoods: + intermediates = outputs[1] + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) + intermediates = [resizer(x) for x in intermediates] + outputs = (outputs[0], intermediates) + return outputs + + +class ControlNetDiffusionInferer(DiffusionInferer): + """ + ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal + forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning. + + Args: + scheduler: diffusion scheduler. + """ + + def __init__(self, scheduler: Scheduler) -> None: + Inferer.__init__(self) + self.scheduler = scheduler + + def __call__( # type: ignore[override] + self, + inputs: torch.Tensor, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + noise: torch.Tensor, + timesteps: torch.Tensor, + cn_cond: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: Input image to which noise is added. + diffusion_model: diffusion model. + controlnet: controlnet sub-network. + noise: random noise, of the same shape as the input. + timesteps: random timesteps. + cn_cond: conditioning image for the ControlNet. + condition: Conditioning for network input. + mode: Conditioning mode for the network. + seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be + provided on the forward (for SPADE-like AE or SPADE-like DM) + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + down_block_res_samples, mid_block_res_sample = controlnet( + x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond + ) + if mode == "concat" and condition is not None: + noisy_image = torch.cat([noisy_image, condition], dim=1) + condition = None + + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg=seg) + + prediction: torch.Tensor = diffuse( + x=noisy_image, + timesteps=timesteps, + context=condition, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + + return prediction + + @torch.no_grad() + def sample( # type: ignore[override] + self, + input_noise: torch.Tensor, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + cn_cond: torch.Tensor, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired sample. + diffusion_model: model to sample from. + controlnet: controlnet sub-network. + cn_cond: conditioning image for the ControlNet. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + if not scheduler: + scheduler = self.scheduler + image = input_noise + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + for t in progress_bar: + # 1. ControlNet forward + down_block_res_samples, mid_block_res_sample = controlnet( + x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond + ) + # 2. predict noise model_output + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg=seg) + + if mode == "concat" and conditioning is not None: + model_input = torch.cat([image, conditioning], dim=1) + model_output = diffuse( + model_input, + timesteps=torch.Tensor((t,)).to(input_noise.device), + context=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + else: + model_output = diffuse( + image, + timesteps=torch.Tensor((t,)).to(input_noise.device), + context=conditioning, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + + # 3. compute previous image: x_t -> x_t-1 + image, _ = scheduler.step(model_output, t, image) + if save_intermediates and t % intermediate_steps == 0: + intermediates.append(image) + if save_intermediates: + return image, intermediates + else: + return image + + @torch.no_grad() + def get_likelihood( # type: ignore[override] + self, + inputs: torch.Tensor, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + cn_cond: torch.Tensor, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple = (0, 255), + scaled_input_range: tuple = (0, 1), + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods for an input. + + Args: + inputs: input images, NxCxHxW[xD] + diffusion_model: model to compute likelihood from + controlnet: controlnet sub-network. + cn_cond: conditioning image for the ControlNet. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + + if not scheduler: + scheduler = self.scheduler + if scheduler._get_name() != "DDPMScheduler": + raise NotImplementedError( + f"Likelihood computation is only compatible with DDPMScheduler," + f" you are using {scheduler._get_name()}" + ) + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + noise = torch.randn_like(inputs).to(inputs.device) + total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) + for t in progress_bar: + timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + down_block_res_samples, mid_block_res_sample = controlnet( + x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond + ) + + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg=seg) + + if mode == "concat" and conditioning is not None: + noisy_image = torch.cat([noisy_image, conditioning], dim=1) + model_output = diffuse( + noisy_image, + timesteps=timesteps, + context=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + else: + model_output = diffuse( + x=noisy_image, + timesteps=timesteps, + context=conditioning, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + # get the model's predicted mean, and variance if it is predicted + if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[t] + alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if scheduler.prediction_type == "epsilon": + pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif scheduler.prediction_type == "sample": + pred_original_sample = model_output + elif scheduler.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output + # 3. Clip "predicted x_0" + if scheduler.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t + current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image + + # get the posterior mean and variance + posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) + + log_posterior_variance = torch.log(posterior_variance) + log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + + if t == 0: + # compute -log p(x_0|x_1) + kl = -super()._get_decoder_log_likelihood( + inputs=inputs, + means=predicted_mean, + log_scales=0.5 * log_predicted_variance, + original_input_range=original_input_range, + scaled_input_range=scaled_input_range, + ) + else: + # compute kl between two normals + kl = 0.5 * ( + -1.0 + + log_predicted_variance + - log_posterior_variance + + torch.exp(log_posterior_variance - log_predicted_variance) + + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) + ) + total_kl += kl.view(kl.shape[0], -1).mean(dim=1) + if save_intermediates: + intermediates.append(kl.cpu()) + + if save_intermediates: + return total_kl, intermediates + else: + return total_kl + + +class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer): + """ + ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet, + and a scheduler, and can be used to perform a signal forward pass for a training iteration, and sample from + the model. + + Args: + scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents. + scale_factor: scale factor to multiply the values of the latent representation before processing it by the + second stage. + ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape. + autoencoder_latent_shape: autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a + difference between the autoencoder's latent shape and the DM shape. + """ + + def __init__( + self, + scheduler: Scheduler, + scale_factor: float = 1.0, + ldm_latent_shape: list | None = None, + autoencoder_latent_shape: list | None = None, + ) -> None: + super().__init__(scheduler=scheduler) + self.scale_factor = scale_factor + if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None): + raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None" "and vice versa.") + self.ldm_latent_shape = ldm_latent_shape + self.autoencoder_latent_shape = autoencoder_latent_shape + if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None: + self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) + self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape) + + def __call__( # type: ignore[override] + self, + inputs: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + noise: torch.Tensor, + timesteps: torch.Tensor, + cn_cond: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted and noise is added. + autoencoder_model: first stage model. + diffusion_model: diffusion model. + controlnet: instance of ControlNet model + noise: random noise, of the same shape as the latent representation. + timesteps: random timesteps. + cn_cond: conditioning tensor for the ControlNet network + condition: conditioning for network input. + mode: Conditioning mode for the network. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + with torch.no_grad(): + latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) + + if cn_cond.shape[2:] != latent.shape[2:]: + cn_cond = F.interpolate(cn_cond, latent.shape[2:]) + + prediction = super().__call__( + inputs=latent, + diffusion_model=diffusion_model, + controlnet=controlnet, + noise=noise, + timesteps=timesteps, + cn_cond=cn_cond, + condition=condition, + mode=mode, + seg=seg, + ) + + return prediction + + @torch.no_grad() + def sample( # type: ignore[override] + self, + input_noise: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + cn_cond: torch.Tensor, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired latent representation. + autoencoder_model: first stage model. + diffusion_model: model to sample from. + controlnet: instance of ControlNet model. + cn_cond: conditioning tensor for the ControlNet network. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + + if ( + isinstance(autoencoder_model, SPADEAutoencoderKL) + and isinstance(diffusion_model, SPADEDiffusionModelUNet) + and autoencoder_model.decoder.label_nc != diffusion_model.label_nc + ): + raise ValueError( + "If both autoencoder_model and diffusion_model implement SPADE, the number of semantic" + "labels for each must be compatible. Got {autoencoder_model.decoder.label_nc} and {diffusion_model.label_nc}" + ) + + if cn_cond.shape[2:] != input_noise.shape[2:]: + cn_cond = F.interpolate(cn_cond, input_noise.shape[2:]) + + outputs = super().sample( + input_noise=input_noise, + diffusion_model=diffusion_model, + controlnet=controlnet, + cn_cond=cn_cond, + scheduler=scheduler, + save_intermediates=save_intermediates, + intermediate_steps=intermediate_steps, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + + if save_intermediates: + latent, latent_intermediates = outputs + else: + latent = outputs + + if self.autoencoder_latent_shape is not None: + latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) + latent_intermediates = [ + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + ] + + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + + image = decode(latent / self.scale_factor) + + if save_intermediates: + intermediates = [] + for latent_intermediate in latent_intermediates: + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + intermediates.append(decode(latent_intermediate / self.scale_factor)) + return image, intermediates + + else: + return image + + @torch.no_grad() + def get_likelihood( # type: ignore[override] + self, + inputs: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + cn_cond: torch.Tensor, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + autoencoder_model: first stage model. + diffusion_model: model to compute likelihood from + controlnet: instance of ControlNet model. + cn_cond: conditioning tensor for the ControlNet network. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + + latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if cn_cond.shape[2:] != latents.shape[2:]: + cn_cond = F.interpolate(cn_cond, latents.shape[2:]) + + if self.ldm_latent_shape is not None: + latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) + + outputs = super().get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + controlnet=controlnet, + cn_cond=cn_cond, + scheduler=scheduler, + save_intermediates=save_intermediates, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + + if save_intermediates and resample_latent_likelihoods: + intermediates = outputs[1] + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) + intermediates = [resizer(x) for x in intermediates] + outputs = (outputs[0], intermediates) + return outputs + + +class VQVAETransformerInferer(nn.Module): + """ + Class to perform inference with a VQVAE + Transformer model. + """ + + def __init__(self) -> None: + Inferer.__init__(self) + + def __call__( + self, + inputs: torch.Tensor, + vqvae_model: VQVAE, + transformer_model: DecoderOnlyTransformer, + ordering: Ordering, + condition: torch.Tensor | None = None, + return_latent: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted. + vqvae_model: first stage model. + transformer_model: autoregressive transformer model. + ordering: ordering of the quantised latent representation. + return_latent: also return latent sequence and spatial dim of the latent. + condition: conditioning for network input. + """ + with torch.no_grad(): + latent = vqvae_model.index_quantize(inputs) + + latent_spatial_dim = tuple(latent.shape[1:]) + latent = latent.reshape(latent.shape[0], -1) + latent = latent[:, ordering.get_sequence_ordering()] + + # get the targets for the loss + target = latent.clone() + # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token. + # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens. + latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings) + # crop the last token as we do not need the probability of the token that follows it + latent = latent[:, :-1] + latent = latent.long() + + # train on a part of the sequence if it is longer than max_seq_length + seq_len = latent.shape[1] + max_seq_len = transformer_model.max_seq_len + if max_seq_len < seq_len: + start = int(torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item()) + else: + start = 0 + prediction: torch.Tensor = transformer_model(x=latent[:, start : start + max_seq_len], context=condition) + if return_latent: + return prediction, target[:, start : start + max_seq_len], latent_spatial_dim + else: + return prediction + + @torch.no_grad() + def sample( + self, + latent_spatial_dim: tuple[int, int, int] | tuple[int, int], + starting_tokens: torch.Tensor, + vqvae_model: VQVAE, + transformer_model: DecoderOnlyTransformer, + ordering: Ordering, + conditioning: torch.Tensor | None = None, + temperature: float = 1.0, + top_k: int | None = None, + verbose: bool = True, + ) -> torch.Tensor: + """ + Sampling function for the VQVAE + Transformer model. + + Args: + latent_spatial_dim: shape of the sampled image. + starting_tokens: starting tokens for the sampling. It must be vqvae_model.num_embeddings value. + vqvae_model: first stage model. + transformer_model: model to sample from. + conditioning: Conditioning for network input. + temperature: temperature for sampling. + top_k: top k sampling. + verbose: if true, prints the progression bar of the sampling process. + """ + seq_len = math.prod(latent_spatial_dim) + + if verbose and has_tqdm: + progress_bar = tqdm(range(seq_len)) + else: + progress_bar = iter(range(seq_len)) + + latent_seq = starting_tokens.long() + for _ in progress_bar: + # if the sequence context is growing too long we must crop it at block_size + if latent_seq.size(1) <= transformer_model.max_seq_len: + idx_cond = latent_seq + else: + idx_cond = latent_seq[:, -transformer_model.max_seq_len :] + + # forward the model to get the logits for the index in the sequence + logits = transformer_model(x=idx_cond, context=conditioning) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float("Inf") + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # remove the chance to be sampled the BOS token + probs[:, vqvae_model.num_embeddings] = 0 + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) + # append sampled index to the running sequence and continue + latent_seq = torch.cat((latent_seq, idx_next), dim=1) + + latent_seq = latent_seq[:, 1:] + latent_seq = latent_seq[:, ordering.get_revert_sequence_ordering()] + latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim) + + return vqvae_model.decode_samples(latent) + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + vqvae_model: VQVAE, + transformer_model: DecoderOnlyTransformer, + ordering: Ordering, + condition: torch.Tensor | None = None, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + verbose: bool = False, + ) -> torch.Tensor: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + vqvae_model: first stage model. + transformer_model: autoregressive transformer model. + ordering: ordering of the quantised latent representation. + condition: conditioning for network input. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + verbose: if true, prints the progression bar of the sampling process. + + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + + with torch.no_grad(): + latent = vqvae_model.index_quantize(inputs) + + latent_spatial_dim = tuple(latent.shape[1:]) + latent = latent.reshape(latent.shape[0], -1) + latent = latent[:, ordering.get_sequence_ordering()] + seq_len = math.prod(latent_spatial_dim) + + # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token. + # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens. + latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings) + latent = latent.long() + + # get the first batch, up to max_seq_length, efficiently + logits = transformer_model(x=latent[:, : transformer_model.max_seq_len], context=condition) + probs = F.softmax(logits, dim=-1) + # target token for each set of logits is the next token along + target = latent[:, 1:] + probs = torch.gather(probs, 2, target[:, : transformer_model.max_seq_len].unsqueeze(2)).squeeze(2) + + # if we have not covered the full sequence we continue with inefficient looping + if probs.shape[1] < target.shape[1]: + if verbose and has_tqdm: + progress_bar = tqdm(range(transformer_model.max_seq_len, seq_len)) + else: + progress_bar = iter(range(transformer_model.max_seq_len, seq_len)) + + for i in progress_bar: + idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1] + # forward the model to get the logits for the index in the sequence + logits = transformer_model(x=idx_cond, context=condition) + # pluck the logits at the final step + logits = logits[:, -1, :] + # apply softmax to convert logits to (normalized) probabilities + p = F.softmax(logits, dim=-1) + # select correct values and append + p = torch.gather(p, 1, target[:, i].unsqueeze(1)) + + probs = torch.cat((probs, p), dim=1) + + # convert to log-likelihood + probs = torch.log(probs) + + # reshape + probs = probs[:, ordering.get_revert_sequence_ordering()] + probs_reshaped = probs.reshape((inputs.shape[0],) + latent_spatial_dim) + if resample_latent_likelihoods: + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) + probs_reshaped = resizer(probs_reshaped[:, None, ...]) + + return probs_reshaped diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 1532215c70..0441cc9cfe 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -430,7 +430,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: batch, channel, height, width, depth = x.shape # norm - x = self.norm(x) + x = self.norm(x.contiguous()) if self.spatial_dims == 2: x = x.view(batch, channel, height * width).transpose(1, 2) @@ -682,7 +682,7 @@ def __init__( ) def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: - h = x + h = x.contiguous() h = self.norm1(h) h = self.nonlinearity(h) @@ -1957,7 +1957,7 @@ def forward( h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) # 7. output block - output: torch.Tensor = self.out(h) + output: torch.Tensor = self.out(h.contiguous()) return output diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py index d53327100e..bffc9c5465 100644 --- a/monai/networks/nets/spade_diffusion_model_unet.py +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -618,7 +618,7 @@ class SPADEDiffusionModelUNet(nn.Module): out_channels: number of output channels. label_nc: number of semantic channels for SPADE normalisation. num_res_blocks: number of residual blocks (see ResnetBlock) per level. - num_channels: tuple of block output channels. + channels: tuple of block output channels. attention_levels: list of levels to add attention. norm_num_groups: number of groups for the normalization. norm_eps: epsilon for the normalization. @@ -641,7 +641,7 @@ def __init__( out_channels: int, label_nc: int, num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - num_channels: Sequence[int] = (32, 64, 64, 64), + channels: Sequence[int] = (32, 64, 64, 64), attention_levels: Sequence[bool] = (False, False, True, True), norm_num_groups: int = 32, norm_eps: float = 1e-6, @@ -667,10 +667,10 @@ def __init__( ) # All number of channels should be multiple of num_groups - if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): raise ValueError("SPADEDiffusionModelUNet expects all num_channels being multiple of norm_num_groups") - if len(num_channels) != len(attention_levels): + if len(channels) != len(attention_levels): raise ValueError("SPADEDiffusionModelUNet expects num_channels being same size of attention_levels") if isinstance(num_head_channels, int): @@ -683,9 +683,9 @@ def __init__( ) if isinstance(num_res_blocks, int): - num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) - if len(num_res_blocks) != len(num_channels): + if len(num_res_blocks) != len(channels): raise ValueError( "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " "`num_channels`." @@ -700,7 +700,7 @@ def __init__( ) self.in_channels = in_channels - self.block_out_channels = num_channels + self.block_out_channels = channels self.out_channels = out_channels self.num_res_blocks = num_res_blocks self.attention_levels = attention_levels @@ -712,7 +712,7 @@ def __init__( self.conv_in = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, - out_channels=num_channels[0], + out_channels=channels[0], strides=1, kernel_size=3, padding=1, @@ -720,9 +720,9 @@ def __init__( ) # time - time_embed_dim = num_channels[0] * 4 + time_embed_dim = channels[0] * 4 self.time_embed = nn.Sequential( - nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) # class embedding @@ -732,11 +732,11 @@ def __init__( # down self.down_blocks = nn.ModuleList([]) - output_channel = num_channels[0] - for i in range(len(num_channels)): + output_channel = channels[0] + for i in range(len(channels)): input_channel = output_channel - output_channel = num_channels[i] - is_final_block = i == len(num_channels) - 1 + output_channel = channels[i] + is_final_block = i == len(channels) - 1 down_block = get_down_block( spatial_dims=spatial_dims, @@ -762,7 +762,7 @@ def __init__( # mid self.middle_block = get_mid_block( spatial_dims=spatial_dims, - in_channels=num_channels[-1], + in_channels=channels[-1], temb_channels=time_embed_dim, norm_num_groups=norm_num_groups, norm_eps=norm_eps, @@ -776,7 +776,7 @@ def __init__( # up self.up_blocks = nn.ModuleList([]) - reversed_block_out_channels = list(reversed(num_channels)) + reversed_block_out_channels = list(reversed(channels)) reversed_num_res_blocks = list(reversed(num_res_blocks)) reversed_attention_levels = list(reversed(attention_levels)) reversed_num_head_channels = list(reversed(num_head_channels)) @@ -784,9 +784,9 @@ def __init__( for i in range(len(reversed_block_out_channels)): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)] + input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)] - is_final_block = i == len(num_channels) - 1 + is_final_block = i == len(channels) - 1 up_block = get_spade_up_block( spatial_dims=spatial_dims, @@ -814,12 +814,12 @@ def __init__( # out self.out = nn.Sequential( - nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True), + nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True), nn.SiLU(), zero_module( Convolution( spatial_dims=spatial_dims, - in_channels=num_channels[0], + in_channels=channels[0], out_channels=out_channels, strides=1, kernel_size=3, diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 2c32eb2cf4..03fa1ceed1 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -126,6 +126,7 @@ version_leq, ) from .nvtx import Range +from .ordering import Ordering from .profiling import ( PerfContext, ProfileHandler, diff --git a/setup.cfg b/setup.cfg index 123da68dfa..0069214de3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,7 +52,7 @@ all = scipy>=1.7.1 pillow tensorboard - gdown>=4.4.0 + gdown==4.6.3 pytorch-ignite==0.4.11 torchvision itk>=5.2 @@ -60,12 +60,12 @@ all = lmdb psutil cucim>=23.2.0 - openslide-python==1.1.2 + openslide-python tifffile imagecodecs pandas einops - transformers<4.22 + transformers<4.22; python_version <= '3.10' mlflow>=1.28.0 clearml>=1.10.0rc0 matplotlib @@ -97,7 +97,7 @@ pillow = tensorboard = tensorboard gdown = - gdown>=4.4.0 + gdown==4.6.3 ignite = pytorch-ignite==0.4.11 torchvision = @@ -113,7 +113,7 @@ psutil = cucim = cucim>=23.2.0 openslide = - openslide-python==1.1.2 + openslide-python tifffile = tifffile imagecodecs = @@ -123,7 +123,7 @@ pandas = einops = einops transformers = - transformers<4.22 + transformers<4.22; python_version <= '3.10' mlflow = mlflow matplotlib = @@ -173,6 +173,7 @@ max_line_length = 120 # B028 https://github.com/Project-MONAI/MONAI/issues/5855 # B907 https://github.com/Project-MONAI/MONAI/issues/5868 # B908 https://github.com/Project-MONAI/MONAI/issues/6503 +# B036 https://github.com/Project-MONAI/MONAI/issues/7396 ignore = E203 E501 @@ -186,6 +187,7 @@ ignore = B028 B907 B908 + B036 per_file_ignores = __init__.py: F401, __main__.py: F401 exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py new file mode 100644 index 0000000000..1f675537dc --- /dev/null +++ b/tests/test_controlnet_inferers.py @@ -0,0 +1,1270 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import ControlNetDiffusionInferer, ControlNetLatentDiffusionInferer +from monai.networks.nets import ( + VQVAE, + AutoencoderKL, + ControlNet, + DiffusionModelUNet, + SPADEAutoencoderKL, + SPADEDiffusionModelUNet, +) +from monai.networks.schedulers import DDIMScheduler, DDPMScheduler +from monai.utils import optional_import + +_, has_scipy = optional_import("scipy") + +CNDM_TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 1, + "channels": [8], + "attention_levels": [True], + "norm_num_groups": 8, + "num_res_blocks": 1, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (2, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 1, + "channels": [8], + "attention_levels": [True], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (2, 1, 8, 8, 8), + ], +] +LATENT_CNDM_TEST_CASES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 16, 16), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 16, 16, 16), + (1, 3, 4, 4, 4), + ], +] +LATENT_CNDM_TEST_CASES_DIFF_SHAPES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12, 12), + (1, 3, 8, 8, 8), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], +] + + +class ControlNetTestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(CNDM_TEST_CASES) + def test_call(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer( + inputs=input, noise=noise, diffusion_model=model, controlnet=controlnet, timesteps=timesteps, cn_cond=mask + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sample_intermediates(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + noise = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_ddpm_sampler(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_ddim_sampler(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sampler_conditioned(self, model_params, controlnet_params, input_shape): + model_params["with_conditioning"] = True + model_params["cross_attention_dim"] = 3 + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_get_likelihood(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet = ControlNet(**controlnet_params) + controlnet.to(device) + controlnet.eval() + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + likelihood, intermediates = inferer.get_likelihood( + inputs=input, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + ) + self.assertEqual(intermediates[0].shape, input.shape) + self.assertEqual(likelihood.shape[0], input.shape[0]) + + @unittest.skipUnless(has_scipy, "Requires scipy library.") + def test_normal_cdf(self): + from scipy.stats import norm + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + x = torch.linspace(-10, 10, 20) + cdf_approx = inferer._approx_standard_normal_cdf(x) + cdf_true = norm.cdf(x) + torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet = ControlNet(**controlnet_params) + controlnet.to(device) + controlnet.eval() + noise = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(len(intermediates), 10) + + +class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_prediction_shape( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + seg=input_seg, + noise=noise, + timesteps=timesteps, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + controlnet=controlnet, + cn_cond=mask, + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_sample_shape( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_sample_intermediates( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + controlnet=controlnet, + cn_cond=mask, + ) + else: + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + controlnet=controlnet, + cn_cond=mask, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_get_likelihoods( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_resample_likelihoods( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + resample_latent_likelihoods=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + resample_latent_likelihoods=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_prediction_shape_conditioned_concat( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + condition=conditioning, + mode="concat", + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + condition=conditioning, + mode="concat", + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_sample_shape_conditioned_concat( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES) + def test_sample_shape_different_latents( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + # We infer the VAE shape + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] + inferer = ControlNetLatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, + ) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + noise=noise, + timesteps=timesteps, + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + ) + self.assertEqual(prediction.shape, latent_shape) + + def test_incompatible_spade_setup(self): + stage_1 = SPADEAutoencoderKL( + spatial_dims=2, + label_nc=6, + in_channels=1, + out_channels=1, + channels=(4, 4), + latent_channels=3, + attention_levels=[False, False], + num_res_blocks=1, + with_encoder_nonlocal_attn=False, + with_decoder_nonlocal_attn=False, + norm_num_groups=4, + ) + stage_2 = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=3, + out_channels=3, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + ) + controlnet = ControlNet( + spatial_dims=2, + in_channels=1, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + conditioning_embedding_num_channels=[16], + ) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + noise = torch.randn((1, 3, 4, 4)).to(device) + mask = torch.randn((1, 1, 4, 4)).to(device) + input_seg = torch.randn((1, 3, 8, 8)).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + with self.assertRaises(ValueError): + _ = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + seg=input_seg, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py new file mode 100644 index 0000000000..ecd4855385 --- /dev/null +++ b/tests/test_diffusion_inferer.py @@ -0,0 +1,226 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import DiffusionInferer +from monai.networks.nets import DiffusionModelUNet +from monai.networks.schedulers import DDIMScheduler, DDPMScheduler +from monai.utils import optional_import + +_, has_scipy = optional_import("scipy") + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8, 8), + ], +] + + +class TestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_call(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer(inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_intermediates(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_ddpm_sampler(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_ddim_sampler(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_sampler_conditioned(self, model_params, input_shape): + model_params["with_conditioning"] = True + model_params["cross_attention_dim"] = 3 + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_get_likelihood(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + likelihood, intermediates = inferer.get_likelihood( + inputs=input, diffusion_model=model, scheduler=scheduler, save_intermediates=True + ) + self.assertEqual(intermediates[0].shape, input.shape) + self.assertEqual(likelihood.shape[0], input.shape[0]) + + @unittest.skipUnless(has_scipy, "Requires scipy library.") + def test_normal_cdf(self): + from scipy.stats import norm + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + + x = torch.linspace(-10, 10, 20) + cdf_approx = inferer._approx_standard_normal_cdf(x) + cdf_true = norm.cdf(x) + torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) + + @parameterized.expand(TEST_CASES) + def test_sampler_conditioned_concat(self, model_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_call_conditioned_concat(self, model_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer( + inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps, condition=conditioning, mode="concat" + ) + self.assertEqual(sample.shape, input_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py index 1218ce6e85..1d831f0976 100644 --- a/tests/test_flexible_unet.py +++ b/tests/test_flexible_unet.py @@ -39,7 +39,7 @@ class DummyEncoder(BaseEncoder): def get_encoder_parameters(cls): basic_dict = {"spatial_dims": 2, "in_channels": 3, "pretrained": False} param_dict_list = [basic_dict] - for key in basic_dict: + for key in basic_dict.keys(): cur_dict = basic_dict.copy() del cur_dict[key] param_dict_list.append(cur_dict) diff --git a/tests/test_invertd.py b/tests/test_invertd.py index cd2e91257a..2e6ee35981 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -112,15 +112,15 @@ def test_invert(self): self.assertTupleEqual(i.shape[1:], (101, 100, 107)) # check the case that different items use different interpolation mode to invert transforms - d = item["image_inverted1"] + j = item["image_inverted1"] # if the interpolation mode is nearest, accumulated diff should be smaller than 1 - self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0) - self.assertTupleEqual(d.shape, (1, 101, 100, 107)) + self.assertLess(torch.sum(j.to(torch.float) - j.to(torch.uint8).to(torch.float)).item(), 1.0) + self.assertTupleEqual(j.shape, (1, 101, 100, 107)) - d = item["label_inverted1"] + k = item["label_inverted1"] # if the interpolation mode is not nearest, accumulated diff should be greater than 10000 - self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0) - self.assertTupleEqual(d.shape, (1, 101, 100, 107)) + self.assertGreater(torch.sum(k.to(torch.float) - k.to(torch.uint8).to(torch.float)).item(), 10000.0) + self.assertTupleEqual(k.shape, (1, 101, 100, 107)) # check labels match reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32) diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py new file mode 100644 index 0000000000..4ab803bb6f --- /dev/null +++ b/tests/test_latent_diffusion_inferer.py @@ -0,0 +1,796 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import LatentDiffusionInferer +from monai.networks.nets import VQVAE, AutoencoderKL, DiffusionModelUNet, SPADEAutoencoderKL, SPADEDiffusionModelUNet +from monai.networks.schedulers import DDPMScheduler + +TEST_CASES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 16, 16), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 16, 16, 16), + (1, 3, 4, 4, 4), + ], + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], +] +TEST_CASES_DIFF_SHAPES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 12, 12, 12), + (1, 3, 8, 8, 8), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], +] + + +class TestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_prediction_shape( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + seg=input_seg, + noise=noise, + timesteps=timesteps, + ) + else: + prediction = inferer( + inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_shape( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_intermediates( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + save_intermediates=True, + intermediate_steps=1, + ) + else: + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, input_shape) + + @parameterized.expand(TEST_CASES) + def test_get_likelihoods( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_resample_likelihoods( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + resample_latent_likelihoods=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + resample_latent_likelihoods=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) + + @parameterized.expand(TEST_CASES) + def test_prediction_shape_conditioned_concat( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + condition=conditioning, + mode="concat", + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + condition=conditioning, + mode="concat", + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_shape_conditioned_concat( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES_DIFF_SHAPES) + def test_sample_shape_different_latents( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + # We infer the VAE shape + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] + inferer = LatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, + ) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps + ) + self.assertEqual(prediction.shape, latent_shape) + + def test_incompatible_spade_setup(self): + stage_1 = SPADEAutoencoderKL( + spatial_dims=2, + label_nc=6, + in_channels=1, + out_channels=1, + channels=(4, 4), + latent_channels=3, + attention_levels=[False, False], + num_res_blocks=1, + with_encoder_nonlocal_attn=False, + with_decoder_nonlocal_attn=False, + norm_num_groups=4, + ) + stage_2 = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=3, + out_channels=3, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + ) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + noise = torch.randn((1, 3, 4, 4)).to(device) + input_seg = torch.randn((1, 3, 8, 8)).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + with self.assertRaises(ValueError): + _ = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_ordering.py b/tests/test_ordering.py index 0c52dba5e5..e6b235e179 100644 --- a/tests/test_ordering.py +++ b/tests/test_ordering.py @@ -182,24 +182,6 @@ ], ] -TEST_2D_RANDOM = [ - [ - { - "ordering_type": OrderingType.RANDOM, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (True, False), - "transpositions_axes": ((1, 0),), - "rot90_axes": ((0, 1),), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - }, - [[0, 1, 2, 3], [0, 1, 3, 2]], - ] -] TEST_3D = [ [ @@ -291,17 +273,6 @@ def test_ordering_transformation_failure(self, input_param): with self.assertRaises(ValueError): Ordering(**input_param) - @parameterized.expand(TEST_2D_RANDOM) - def test_random(self, input_param, not_in_expected_sequence_ordering): - ordering = Ordering(**input_param) - - not_in = [ - np.array_equal(sequence, ordering.get_sequence_ordering(), equal_nan=True) - for sequence in not_in_expected_sequence_ordering - ] - - self.assertFalse(np.any(not_in)) - @parameterized.expand(TEST_REVERT) def test_revert(self, input_param): sequence = np.random.randint(0, 100, size=input_param["dimensions"]).flatten() diff --git a/test_spade_autoencoderkl.py b/tests/test_spade_autoencoderkl.py similarity index 100% rename from test_spade_autoencoderkl.py rename to tests/test_spade_autoencoderkl.py diff --git a/test_spade_diffusion_model_unet.py b/tests/test_spade_diffusion_model_unet.py similarity index 92% rename from test_spade_diffusion_model_unet.py rename to tests/test_spade_diffusion_model_unet.py index c8a2103cf6..113e58ed89 100644 --- a/test_spade_diffusion_model_unet.py +++ b/tests/test_spade_diffusion_model_unet.py @@ -26,7 +26,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "label_nc": 3, @@ -38,7 +38,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": (1, 1, 2), - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "label_nc": 3, @@ -50,7 +50,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "resblock_updown": True, @@ -63,7 +63,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, @@ -76,7 +76,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, @@ -90,7 +90,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -103,7 +103,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, True, True), "num_head_channels": (0, 2, 4), "norm_num_groups": 8, @@ -119,7 +119,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "label_nc": 3, @@ -132,7 +132,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "label_nc": 3, @@ -144,7 +144,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "resblock_updown": True, @@ -157,7 +157,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, @@ -170,7 +170,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, @@ -184,7 +184,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -197,7 +197,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": (0, 0, 4), "norm_num_groups": 8, @@ -213,7 +213,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -229,7 +229,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -246,7 +246,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -279,7 +279,7 @@ def test_timestep_with_wrong_shape(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -296,7 +296,7 @@ def test_label_with_wrong_shape(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -313,7 +313,7 @@ def test_shape_with_different_in_channel_out_channel(self): in_channels=in_channels, out_channels=out_channels, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -331,7 +331,7 @@ def test_model_channels_not_multiple_of_norm_num_group(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 12), + channels=(8, 8, 12), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -344,13 +344,13 @@ def test_attention_levels_with_different_length_num_head_channels(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), num_head_channels=(0, 2), norm_num_groups=8, ) - def test_num_res_blocks_with_different_length_num_channels(self): + def test_num_res_blocks_with_different_length_channels(self): with self.assertRaises(ValueError): SPADEDiffusionModelUNet( spatial_dims=2, @@ -358,7 +358,7 @@ def test_num_res_blocks_with_different_length_num_channels(self): in_channels=1, out_channels=1, num_res_blocks=(1, 1), - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -370,7 +370,7 @@ def test_shape_conditioned_models(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), with_conditioning=True, transformer_num_layers=1, @@ -395,7 +395,7 @@ def test_with_conditioning_cross_attention_dim_none(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), with_conditioning=True, transformer_num_layers=1, @@ -410,7 +410,7 @@ def test_context_with_conditioning_none(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), with_conditioning=False, transformer_num_layers=1, @@ -433,7 +433,7 @@ def test_shape_conditioned_models_class_conditioning(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), norm_num_groups=8, num_head_channels=8, @@ -455,7 +455,7 @@ def test_conditioned_models_no_class_labels(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), norm_num_groups=8, num_head_channels=8, @@ -469,7 +469,7 @@ def test_conditioned_models_no_class_labels(self): seg=torch.rand((1, 3, 16, 32)), ) - def test_model_num_channels_not_same_size_of_attention_levels(self): + def test_model_channels_not_same_size_of_attention_levels(self): with self.assertRaises(ValueError): SPADEDiffusionModelUNet( spatial_dims=2, @@ -477,7 +477,7 @@ def test_model_num_channels_not_same_size_of_attention_levels(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False), norm_num_groups=8, num_head_channels=8, @@ -518,7 +518,7 @@ def test_shape_with_different_in_channel_out_channel(self): in_channels=in_channels, out_channels=out_channels, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), norm_num_groups=4, ) @@ -537,7 +537,7 @@ def test_shape_conditioned_models(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(16, 16, 16), + channels=(16, 16, 16), attention_levels=(False, False, True), norm_num_groups=16, with_conditioning=True, diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py new file mode 100644 index 0000000000..1a511d287b --- /dev/null +++ b/tests/test_vqvaetransformer_inferer.py @@ -0,0 +1,284 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import VQVAETransformerInferer +from monai.networks.nets import VQVAE, DecoderOnlyTransformer +from monai.utils.ordering import Ordering, OrderingType + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (8, 8), + "num_res_channels": (8, 8), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_res_layers": 1, + "num_embeddings": 16, + "embedding_dim": 8, + }, + { + "num_tokens": 16 + 1, + "max_seq_len": 4, + "attn_layers_dim": 4, + "attn_layers_depth": 2, + "attn_layers_heads": 1, + "with_cross_attention": False, + }, + {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 2, "dimensions": (2, 2, 2)}, + (2, 1, 8, 8), + (2, 4, 17), + (2, 2, 2), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (8, 8), + "num_res_channels": (8, 8), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_res_layers": 1, + "num_embeddings": 16, + "embedding_dim": 8, + }, + { + "num_tokens": 16 + 1, + "max_seq_len": 8, + "attn_layers_dim": 4, + "attn_layers_depth": 2, + "attn_layers_heads": 1, + "with_cross_attention": False, + }, + {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 3, "dimensions": (2, 2, 2, 2)}, + (2, 1, 8, 8, 8), + (2, 8, 17), + (2, 2, 2, 2), + ], +] + + +class TestVQVAETransformerInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_prediction_shape( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + stage_2 = DecoderOnlyTransformer(**stage_2_params) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering) + self.assertEqual(prediction.shape, logits_shape) + + @parameterized.expand(TEST_CASES) + def test_prediction_shape_shorter_sequence( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + max_seq_len = 3 + stage_2_params_shorter = dict(stage_2_params) + stage_2_params_shorter["max_seq_len"] = max_seq_len + stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering) + cropped_logits_shape = (logits_shape[0], max_seq_len, logits_shape[2]) + self.assertEqual(prediction.shape, cropped_logits_shape) + + def test_sample(self): + stage_1 = VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(8, 8), + num_res_channels=(8, 8), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + num_res_layers=1, + num_embeddings=16, + embedding_dim=8, + ) + stage_2 = DecoderOnlyTransformer( + num_tokens=16 + 1, + max_seq_len=4, + attn_layers_dim=4, + attn_layers_depth=2, + attn_layers_heads=1, + with_cross_attention=False, + ) + ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2)) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + inferer = VQVAETransformerInferer() + + starting_token = 16 # from stage_1 num_embeddings + + sample = inferer.sample( + latent_spatial_dim=(2, 2), + starting_tokens=starting_token * torch.ones((2, 1), device=device), + vqvae_model=stage_1, + transformer_model=stage_2, + ordering=ordering, + ) + self.assertEqual(sample.shape, (2, 1, 8, 8)) + + def test_sample_shorter_sequence(self): + stage_1 = VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(8, 8), + num_res_channels=(8, 8), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + num_res_layers=1, + num_embeddings=16, + embedding_dim=8, + ) + stage_2 = DecoderOnlyTransformer( + num_tokens=16 + 1, + max_seq_len=2, + attn_layers_dim=4, + attn_layers_depth=2, + attn_layers_heads=1, + with_cross_attention=False, + ) + ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2)) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + inferer = VQVAETransformerInferer() + + starting_token = 16 # from stage_1 num_embeddings + + sample = inferer.sample( + latent_spatial_dim=(2, 2), + starting_tokens=starting_token * torch.ones((2, 1), device=device), + vqvae_model=stage_1, + transformer_model=stage_2, + ordering=ordering, + ) + self.assertEqual(sample.shape, (2, 1, 8, 8)) + + @parameterized.expand(TEST_CASES) + def test_get_likelihood( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + stage_2 = DecoderOnlyTransformer(**stage_2_params) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + likelihood = inferer.get_likelihood( + inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering + ) + self.assertEqual(likelihood.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_get_likelihood_shorter_sequence( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + max_seq_len = 3 + stage_2_params_shorter = dict(stage_2_params) + stage_2_params_shorter["max_seq_len"] = max_seq_len + stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + likelihood = inferer.get_likelihood( + inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering + ) + self.assertEqual(likelihood.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_get_likelihood_resampling( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + stage_2 = DecoderOnlyTransformer(**stage_2_params) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + likelihood = inferer.get_likelihood( + inputs=input, + vqvae_model=stage_1, + transformer_model=stage_2, + ordering=ordering, + resample_latent_likelihoods=True, + resample_interpolation_mode="nearest", + ) + self.assertEqual(likelihood.shape, input_shape) + + +if __name__ == "__main__": + unittest.main() From 41fb3ff8af39529b0641c9b1d3341987cafac62b Mon Sep 17 00:00:00 2001 From: vgrau98 <35843843+vgrau98@users.noreply.github.com> Date: Thu, 18 Jan 2024 17:21:27 +0100 Subject: [PATCH 12/32] [Attention block] relative positional embedding (#7346) Fixes #7356 ### Description Add relative positinoal embedding in attention block as described in https://arxiv.org/pdf/2112.01526.pdf Largely inspired by https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py Can be useful for #6357 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: vgrau98 Signed-off-by: vgrau98 <35843843+vgrau98@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/networks.rst | 6 + monai/networks/blocks/attention_utils.py | 128 +++++++++++++++++++++ monai/networks/blocks/rel_pos_embedding.py | 56 +++++++++ monai/networks/blocks/selfattention.py | 33 +++++- monai/networks/layers/factories.py | 13 ++- monai/networks/layers/utils.py | 15 ++- tests/test_selfattention.py | 21 +++- 7 files changed, 262 insertions(+), 10 deletions(-) create mode 100644 monai/networks/blocks/attention_utils.py create mode 100644 monai/networks/blocks/rel_pos_embedding.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index f9375f1e97..556bf12d50 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -248,6 +248,12 @@ Blocks .. autoclass:: monai.apps.reconstruction.networks.blocks.varnetblock.VarNetBlock :members: +`Attention utilities` +~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: monai.networks.blocks.attention_utils +.. autofunction:: monai.networks.blocks.attention_utils.get_rel_pos +.. autofunction:: monai.networks.blocks.attention_utils.add_decomposed_rel_pos + N-Dim Fourier Transform ~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: monai.networks.blocks.fft_utils_t diff --git a/monai/networks/blocks/attention_utils.py b/monai/networks/blocks/attention_utils.py new file mode 100644 index 0000000000..8c9002a16e --- /dev/null +++ b/monai/networks/blocks/attention_utils.py @@ -0,0 +1,128 @@ +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + rel_pos_resized: torch.Tensor = torch.Tensor() + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear" + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple +) -> torch.Tensor: + r""" + Calculate decomposed Relative Positional Embeddings from mvitv2 implementation: + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Only 2D and 3D are supported. + + Encoding the relative position of tokens in the attention matrix: tokens spaced a distance + `d` apart will have the same embedding value (unlike absolute positional embedding). + + .. math:: + Attn_{logits}(Q, K) = (QK^{T} + E_{rel})*scale + + where + + .. math:: + E_{ij}^{(rel)} = Q_{i}.R_{p(i), p(j)} + + with :math:`R_{p(i), p(j)} \in R^{dim}` and :math:`p(i), p(j)`, + respectively spatial positions of element :math:`i` and :math:`j` + + When using "decomposed" relative positional embedding, positional embedding is defined ("decomposed") as follow: + + .. math:: + R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)} + + with :math:`n = 1...dim` + + Decomposed relative positional embedding reduces the complexity from :math:`\mathcal{O}(d1*...*dn)` to + :math:`\mathcal{O}(d1+...+dn)` compared with classical relative positional embedding. + + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C). + rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis. + q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n). + k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n). + + Returns: + attn (Tensor): attention logits with added relative positional embeddings. + """ + rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0]) + rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1]) + + batch, _, dim = q.shape + + if len(rel_pos_lst) == 2: + q_h, q_w = q_size[:2] + k_h, k_w = k_size[:2] + r_q = q.reshape(batch, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw) + + attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( + batch, q_h * q_w, k_h * k_w + ) + elif len(rel_pos_lst) == 3: + q_h, q_w, q_d = q_size[:3] + k_h, k_w, k_d = k_size[:3] + + rd = get_rel_pos(q_d, k_d, rel_pos_lst[2]) + + r_q = q.reshape(batch, q_h, q_w, q_d, dim) + rel_h = torch.einsum("bhwdc,hkc->bhwdk", r_q, rh) + rel_w = torch.einsum("bhwdc,wkc->bhwdk", r_q, rw) + rel_d = torch.einsum("bhwdc,wkc->bhwdk", r_q, rd) + + attn = ( + attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d) + + rel_h[:, :, :, :, None, None] + + rel_w[:, :, :, None, :, None] + + rel_d[:, :, :, None, None, :] + ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d) + + return attn diff --git a/monai/networks/blocks/rel_pos_embedding.py b/monai/networks/blocks/rel_pos_embedding.py new file mode 100644 index 0000000000..e53e5841b0 --- /dev/null +++ b/monai/networks/blocks/rel_pos_embedding.py @@ -0,0 +1,56 @@ +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Iterable, Tuple + +import torch +from torch import nn + +from monai.networks.blocks.attention_utils import add_decomposed_rel_pos +from monai.utils.misc import ensure_tuple_size + + +class DecomposedRelativePosEmbedding(nn.Module): + def __init__(self, s_input_dims: Tuple[int, int] | Tuple[int, int, int], c_dim: int, num_heads: int) -> None: + """ + Args: + s_input_dims (Tuple): input spatial dimension. (H, W) or (H, W, D) + c_dim (int): channel dimension + num_heads(int): number of attention heads + """ + super().__init__() + + # validate inputs + if not isinstance(s_input_dims, Iterable) or len(s_input_dims) not in [2, 3]: + raise ValueError("s_input_dims must be set as follows: (H, W) or (H, W, D)") + + self.s_input_dims = s_input_dims + self.c_dim = c_dim + self.num_heads = num_heads + self.rel_pos_arr = nn.ParameterList( + [nn.Parameter(torch.zeros(2 * dim_input_size - 1, c_dim)) for dim_input_size in s_input_dims] + ) + + def forward(self, x: torch.Tensor, att_mat: torch.Tensor, q: torch.Tensor) -> torch.Tensor: + """""" + batch = x.shape[0] + h, w, d = ensure_tuple_size(self.s_input_dims, 3, 1) + + att_mat = add_decomposed_rel_pos( + att_mat.contiguous().view(batch * self.num_heads, h * w * d, h * w * d), + q.contiguous().view(batch * self.num_heads, h * w * d, -1), + self.rel_pos_arr, + (h, w) if d == 1 else (h, w, d), + (h, w) if d == 1 else (h, w, d), + ) + + att_mat = att_mat.reshape(batch, self.num_heads, h * w * d, h * w * d) + return att_mat diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 7c81c1704f..3bef24b4e8 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -11,9 +11,12 @@ from __future__ import annotations +from typing import Optional, Tuple + import torch import torch.nn as nn +from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -23,6 +26,7 @@ class SABlock(nn.Module): """ A self-attention block, based on: "Dosovitskiy et al., An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + One can setup relative positional embedding as described in """ def __init__( @@ -32,6 +36,8 @@ def __init__( dropout_rate: float = 0.0, qkv_bias: bool = False, save_attn: bool = False, + rel_pos_embedding: Optional[str] = None, + input_size: Optional[Tuple] = None, ) -> None: """ Args: @@ -39,6 +45,10 @@ def __init__( num_heads (int): number of attention heads. dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. + rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. + For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. + input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative + positional parameter size. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. """ @@ -62,11 +72,30 @@ def __init__( self.scale = self.head_dim**-0.5 self.save_attn = save_attn self.att_mat = torch.Tensor() + self.rel_positional_embedding = ( + get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads) + if rel_pos_embedding is not None + else None + ) + self.input_size = input_size + + def forward(self, x: torch.Tensor): + """ + Args: + x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C - def forward(self, x): + Return: + torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C + """ output = self.input_rearrange(self.qkv(x)) q, k, v = output[0], output[1], output[2] - att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) + att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale + + # apply relative positional embedding if defined + att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + + att_mat = att_mat.softmax(dim=-1) + if self.save_attn: # no gradients and new tensor; # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 4fc2c16f73..29b72a4f37 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -70,7 +70,7 @@ def use_factory(fact_args): from monai.networks.utils import has_nvfuser_instance_norm from monai.utils import ComponentStore, look_up_option, optional_import -__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"] +__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "RelPosEmbedding", "split_args"] class LayerFactory(ComponentStore): @@ -201,6 +201,10 @@ def split_args(args): Conv = LayerFactory(name="Convolution layers", description="Factory for creating convolution layers.") Pool = LayerFactory(name="Pooling layers", description="Factory for creating pooling layers.") Pad = LayerFactory(name="Padding layers", description="Factory for creating padding layers.") +RelPosEmbedding = LayerFactory( + name="Relative positional embedding layers", + description="Factory for creating relative positional embedding factory", +) @Dropout.factory_function("dropout") @@ -468,3 +472,10 @@ def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | """ types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d) return types[dim - 1] + + +@RelPosEmbedding.factory_function("decomposed") +def decomposed_rel_pos_embedding() -> type[nn.Module]: + from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding + + return DecomposedRelativePosEmbedding diff --git a/monai/networks/layers/utils.py b/monai/networks/layers/utils.py index ace1af27b6..8676f74638 100644 --- a/monai/networks/layers/utils.py +++ b/monai/networks/layers/utils.py @@ -11,9 +11,11 @@ from __future__ import annotations +from typing import Optional + import torch.nn -from monai.networks.layers.factories import Act, Dropout, Norm, Pool, split_args +from monai.networks.layers.factories import Act, Dropout, Norm, Pool, RelPosEmbedding, split_args from monai.utils import has_option __all__ = ["get_norm_layer", "get_act_layer", "get_dropout_layer", "get_pool_layer"] @@ -124,3 +126,14 @@ def get_pool_layer(name: tuple | str, spatial_dims: int | None = 1): pool_name, pool_args = split_args(name) pool_type = Pool[pool_name, spatial_dims] return pool_type(**pool_args) + + +def get_rel_pos_embedding_layer(name: tuple | str, s_input_dims: Optional[tuple], c_dim: int, num_heads: int): + embedding_name, embedding_args = split_args(name) + embedding_type = RelPosEmbedding[embedding_name] + # create a dictionary with the default values which can be overridden by embedding_args + kw_args = {"s_input_dims": s_input_dims, "c_dim": c_dim, "num_heads": num_heads, **embedding_args} + # filter out unused argument names + kw_args = {k: v for k, v in kw_args.items() if has_option(embedding_type, k)} + + return embedding_type(**kw_args) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 6062b5352f..0d0553ed2c 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -20,6 +20,7 @@ from monai.networks import eval_mode from monai.networks.blocks.selfattention import SABlock +from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import einops, has_einops = optional_import("einops") @@ -28,12 +29,20 @@ for dropout_rate in np.linspace(0, 1, 4): for hidden_size in [360, 480, 600, 768]: for num_heads in [4, 6, 8, 12]: - test_case = [ - {"hidden_size": hidden_size, "num_heads": num_heads, "dropout_rate": dropout_rate}, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_SABLOCK.append(test_case) + for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: + for input_size in [(16, 32), (8, 8, 8)]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding, + "input_size": input_size, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_SABLOCK.append(test_case) class TestResBlock(unittest.TestCase): From f15a173e72e7a36188c98156f1d06da41d895075 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 1 Feb 2024 16:34:53 +0000 Subject: [PATCH 13/32] 6676 port generative engines (#7406) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Part of #6676 . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li Signed-off-by: Mark Graham Signed-off-by: dongy Signed-off-by: KumoLiu Signed-off-by: myron Signed-off-by: kaibo Signed-off-by: monai-bot Signed-off-by: elitap Signed-off-by: Felix Schnabel Signed-off-by: YanxuanLiu Signed-off-by: ytl0623 Signed-off-by: Dženan Zukić Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Ishan Dutta Signed-off-by: dependabot[bot] Signed-off-by: Mark Graham Signed-off-by: vgrau98 Signed-off-by: vgrau98 <35843843+vgrau98@users.noreply.github.com> Signed-off-by: heyufan1995 Signed-off-by: binliu Signed-off-by: axel.vlaminck Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com> Co-authored-by: Dong Yang Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: myron Co-authored-by: Kaibo Tang <99367900+kvttt@users.noreply.github.com> Co-authored-by: monai-bot <64792179+monai-bot@users.noreply.github.com> Co-authored-by: elitap Co-authored-by: Felix Schnabel Co-authored-by: YanxuanLiu <104543031+YanxuanLiu@users.noreply.github.com> Co-authored-by: ytl0623 Co-authored-by: Dženan Zukić Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ishan Dutta Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: KumoLiu Co-authored-by: Kaibo Tang Co-authored-by: vgrau98 <35843843+vgrau98@users.noreply.github.com> Co-authored-by: Yufan He <59374597+heyufan1995@users.noreply.github.com> Co-authored-by: binliunls <107988372+binliunls@users.noreply.github.com> Co-authored-by: Ben Murray Co-authored-by: axel.vlaminck --- .github/workflows/chatops.yml | 2 +- .github/workflows/codeql-analysis.yml | 4 +- .github/workflows/cron-ngc-bundle.yml | 2 +- .github/workflows/docker.yml | 6 +- .github/workflows/pythonapp-min.yml | 6 +- .github/workflows/pythonapp.yml | 8 +- .github/workflows/release.yml | 10 +- .github/workflows/setupapp.yml | 4 +- .github/workflows/weekly-preview.yml | 2 +- docs/requirements.txt | 2 +- docs/source/engines.rst | 5 + docs/source/losses.rst | 5 + monai/apps/auto3dseg/data_analyzer.py | 26 +- monai/apps/detection/utils/anchor_utils.py | 8 +- monai/apps/utils.py | 2 +- monai/auto3dseg/analyzer.py | 4 +- monai/data/decathlon_datalist.py | 6 +- monai/data/image_reader.py | 12 +- monai/engines/__init__.py | 4 +- monai/engines/evaluator.py | 51 ++- monai/engines/trainer.py | 335 +++++++++++++++++- monai/engines/utils.py | 77 +++- monai/handlers/mlflow_handler.py | 2 +- monai/losses/__init__.py | 2 +- monai/losses/deform.py | 96 ++++- monai/losses/image_dissimilarity.py | 4 +- monai/metrics/hausdorff_distance.py | 2 +- monai/metrics/surface_dice.py | 2 +- monai/metrics/surface_distance.py | 2 +- monai/metrics/utils.py | 12 +- monai/networks/nets/swin_unetr.py | 2 +- monai/networks/nets/transchex.py | 49 +-- monai/networks/nets/vqvae.py | 14 +- monai/transforms/croppad/dictionary.py | 9 +- monai/transforms/inverse.py | 12 +- monai/transforms/io/array.py | 17 +- monai/transforms/utility/array.py | 11 +- monai/transforms/utility/dictionary.py | 6 +- monai/utils/dist.py | 9 +- monai/utils/misc.py | 6 +- requirements-dev.txt | 6 +- setup.cfg | 2 + tests/min_tests.py | 1 + tests/padders.py | 3 + tests/test_diffusion_loss.py | 116 ++++++ tests/test_hilbert_transform.py | 20 +- tests/test_image_filter.py | 16 + .../test_integration_workflows_adversarial.py | 173 +++++++++ tests/test_prepare_batch_diffusion.py | 104 ++++++ tests/test_save_image.py | 16 + tests/test_set_visible_devices.py | 7 + tests/test_spacing.py | 8 +- 52 files changed, 1155 insertions(+), 155 deletions(-) create mode 100644 tests/test_diffusion_loss.py create mode 100644 tests/test_integration_workflows_adversarial.py create mode 100644 tests/test_prepare_batch_diffusion.py diff --git a/.github/workflows/chatops.yml b/.github/workflows/chatops.yml index b4e201a0d9..59c7d070b4 100644 --- a/.github/workflows/chatops.yml +++ b/.github/workflows/chatops.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest steps: - name: dispatch - uses: peter-evans/slash-command-dispatch@v3.0.1 + uses: peter-evans/slash-command-dispatch@v3.0.2 with: token: ${{ secrets.PR_MAINTAIN }} reaction-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 3d32ae407a..18f1519b5a 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -42,7 +42,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -72,4 +72,4 @@ jobs: BUILD_MONAI=1 ./runtests.sh --build - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 diff --git a/.github/workflows/cron-ngc-bundle.yml b/.github/workflows/cron-ngc-bundle.yml index 0bba630d03..84666204a9 100644 --- a/.github/workflows/cron-ngc-bundle.yml +++ b/.github/workflows/cron-ngc-bundle.yml @@ -19,7 +19,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index f51e4fdf76..229ae675f5 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -26,7 +26,7 @@ jobs: ref: dev fetch-depth: 0 - name: Set up Python 3.9 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.9' - shell: bash @@ -36,7 +36,7 @@ jobs: python setup.py build cat build/lib/monai/_version.py - name: Upload version - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: _version.py path: build/lib/monai/_version.py @@ -56,7 +56,7 @@ jobs: with: ref: dev - name: Download version - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: _version.py - name: docker_build diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml index 558c270e33..7b7930bdf5 100644 --- a/.github/workflows/pythonapp-min.yml +++ b/.github/workflows/pythonapp-min.yml @@ -30,7 +30,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: Prepare pip wheel @@ -76,7 +76,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Prepare pip wheel @@ -121,7 +121,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: Prepare pip wheel diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index ad8b555dd4..29a79759e0 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -28,7 +28,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp @@ -69,7 +69,7 @@ jobs: disk-root: "D:" - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: Prepare pip wheel @@ -128,7 +128,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp @@ -209,7 +209,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7197215486..a03d2cea6c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -19,7 +19,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install setuptools @@ -66,7 +66,7 @@ jobs: - if: matrix.python-version == '3.9' && startsWith(github.ref, 'refs/tags/') name: Upload artifacts - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: dist path: dist/ @@ -97,7 +97,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python 3.9 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.9' - shell: bash @@ -108,7 +108,7 @@ jobs: python setup.py build cat build/lib/monai/_version.py - name: Upload version - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: _version.py path: build/lib/monai/_version.py @@ -125,7 +125,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download version - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: _version.py - name: Set tag diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index 0ff7162bee..82394a86dd 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -83,7 +83,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: cache weekly timestamp @@ -120,7 +120,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp diff --git a/.github/workflows/weekly-preview.yml b/.github/workflows/weekly-preview.yml index c631982745..e94e1dac5a 100644 --- a/.github/workflows/weekly-preview.yml +++ b/.github/workflows/weekly-preview.yml @@ -14,7 +14,7 @@ jobs: ref: dev fetch-depth: 0 - name: Set up Python 3.9 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.9' - name: Install setuptools diff --git a/docs/requirements.txt b/docs/requirements.txt index a9bbc384f8..e5bedf8552 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -21,7 +21,7 @@ sphinxcontrib-serializinghtml sphinx-autodoc-typehints==1.11.1 pandas einops -transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157 +transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157 mlflow>=1.28.0 clearml>=1.10.0rc0 tensorboardX diff --git a/docs/source/engines.rst b/docs/source/engines.rst index afb2682822..a015c7b2a3 100644 --- a/docs/source/engines.rst +++ b/docs/source/engines.rst @@ -30,6 +30,11 @@ Workflows .. autoclass:: GanTrainer :members: +`AdversarialTrainer` +~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: AdversarialTrainer + :members: + `Evaluator` ~~~~~~~~~~~ .. autoclass:: Evaluator diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 568c7dfc77..e929e9d605 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -96,6 +96,11 @@ Registration Losses .. autoclass:: BendingEnergyLoss :members: +`DiffusionLoss` +~~~~~~~~~~~~~~~ +.. autoclass:: DiffusionLoss + :members: + `LocalNormalizedCrossCorrelationLoss` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: LocalNormalizedCrossCorrelationLoss diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index 9280fb5be5..15e56abfea 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -28,7 +28,7 @@ from monai.data import DataLoader, Dataset, partition_dataset from monai.data.utils import no_collation from monai.transforms import Compose, EnsureTyped, LoadImaged, Orientationd -from monai.utils import StrEnum, min_version, optional_import +from monai.utils import ImageMetaKey, StrEnum, min_version, optional_import from monai.utils.enums import DataStatsKeys, ImageStatsKeys @@ -343,19 +343,25 @@ def _get_all_case_stats( d = summarizer(batch_data) except BaseException as err: if "image_meta_dict" in batch_data.keys(): - filename = batch_data["image_meta_dict"]["filename_or_obj"] + filename = batch_data["image_meta_dict"][ImageMetaKey.FILENAME_OR_OBJ] else: - filename = batch_data[self.image_key].meta["filename_or_obj"] + filename = batch_data[self.image_key].meta[ImageMetaKey.FILENAME_OR_OBJ] logger.info(f"Unable to process data {filename} on {device}. {err}") if self.device.type == "cuda": logger.info("DataAnalyzer `device` set to GPU execution hit an exception. Falling back to `cpu`.") - batch_data[self.image_key] = batch_data[self.image_key].to("cpu") - if self.label_key is not None: - label = batch_data[self.label_key] - if not _label_argmax: - label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] - batch_data[self.label_key] = label.to("cpu") - d = summarizer(batch_data) + try: + batch_data[self.image_key] = batch_data[self.image_key].to("cpu") + if self.label_key is not None: + label = batch_data[self.label_key] + if not _label_argmax: + label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] + batch_data[self.label_key] = label.to("cpu") + d = summarizer(batch_data) + except BaseException as err: + logger.info(f"Unable to process data {filename} on {device}. {err}") + continue + else: + continue stats_by_cases = { DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH], diff --git a/monai/apps/detection/utils/anchor_utils.py b/monai/apps/detection/utils/anchor_utils.py index baaa7ce874..283169b653 100644 --- a/monai/apps/detection/utils/anchor_utils.py +++ b/monai/apps/detection/utils/anchor_utils.py @@ -369,8 +369,12 @@ class AnchorGeneratorWithAnchorShape(AnchorGenerator): def __init__( self, feature_map_scales: Sequence[int] | Sequence[float] = (1, 2, 4, 8), - base_anchor_shapes: Sequence[Sequence[int]] - | Sequence[Sequence[float]] = ((32, 32, 32), (48, 20, 20), (20, 48, 20), (20, 20, 48)), + base_anchor_shapes: Sequence[Sequence[int]] | Sequence[Sequence[float]] = ( + (32, 32, 32), + (48, 20, 20), + (20, 48, 20), + (20, 20, 48), + ), indexing: str = "ij", ) -> None: nn.Module.__init__(self) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index d2dd63b958..442dbabba0 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -30,7 +30,7 @@ from monai.config.type_definitions import PathLike from monai.utils import look_up_option, min_version, optional_import -gdown, has_gdown = optional_import("gdown", "4.4") +gdown, has_gdown = optional_import("gdown", "4.6.3") if TYPE_CHECKING: from tqdm import tqdm diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py index 654999d439..56419da4cb 100644 --- a/monai/auto3dseg/analyzer.py +++ b/monai/auto3dseg/analyzer.py @@ -256,7 +256,7 @@ def __call__(self, data): ) report[ImageStatsKeys.SIZEMM] = [ - int(a * b) for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING]) + a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING]) ] report[ImageStatsKeys.INTENSITY] = [ @@ -460,7 +460,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe torch.set_grad_enabled(False) ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore - ndas_label: MetaTensor = d[self.label_key] # (H,W,D) + ndas_label: MetaTensor = d[self.label_key].astype(torch.int8) # (H,W,D) if ndas_label.shape != ndas[0].shape: raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 6f163f972e..14765dcfaa 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -24,13 +24,11 @@ @overload -def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: - ... +def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: ... @overload -def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: - ... +def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: ... def _compute_path(base_dir, element, check_path=False): diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 0823d11834..2361bb63a7 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -168,8 +168,8 @@ class ITKReader(ImageReader): series_name: the name of the DICOM series if there are multiple ones. used when loading DICOM series. reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array. - If ``False``, the spatial indexing follows the numpy convention; - otherwise, the spatial indexing convention is reversed to be compatible with ITK. Default is ``False``. + If ``False``, the spatial indexing convention is reversed to be compatible with ITK; + otherwise, the spatial indexing follows the numpy convention. Default is ``False``. This option does not affect the metadata. series_meta: whether to load the metadata of the DICOM series (using the metadata from the first slice). This flag is checked only when loading DICOM series. Default is ``False``. @@ -1323,7 +1323,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: header = dict(i.header) if self.index_order == "C": header = self._convert_f_to_c_order(header) - header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i) + header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(header) if self.affine_lps_to_ras: header = self._switch_lps_ras(header) @@ -1344,7 +1344,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: return _stack_images(img_array, compatible_meta), compatible_meta - def _get_affine(self, img: NrrdImage) -> np.ndarray: + def _get_affine(self, header: dict) -> np.ndarray: """ Get the affine matrix of the image, it can be used to correct spacing, orientation or execute spatial transforms. @@ -1353,8 +1353,8 @@ def _get_affine(self, img: NrrdImage) -> np.ndarray: img: A `NrrdImage` loaded from image file """ - direction = img.header["space directions"] - origin = img.header["space origin"] + direction = header["space directions"] + origin = header["space origin"] x, y = direction.shape affine_diam = min(x, y) + 1 diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index d8dc51f620..93cc40e292 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -12,12 +12,14 @@ from __future__ import annotations from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator -from .trainer import GanTrainer, SupervisedTrainer, Trainer +from .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer from .utils import ( + DiffusionPrepareBatch, IterationEvents, PrepareBatch, PrepareBatchDefault, PrepareBatchExtraInput, + VPredictionPrepareBatch, default_make_latent, default_metric_cmp_fn, default_prepare_batch, diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 119853d5c5..2c8dfe6b85 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -11,12 +11,14 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence import torch from torch.utils.data import DataLoader from monai.config import IgniteInfo, KeysCollection +from monai.data import MetaTensor from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer @@ -25,7 +27,7 @@ from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys -from monai.utils.module import look_up_option +from monai.utils.module import look_up_option, pytorch_after if TYPE_CHECKING: from ignite.engine import Engine, EventEnum @@ -213,6 +215,10 @@ class SupervisedEvaluator(Evaluator): `device`, `non_blocking`. amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to + `torch.Tensor` before forward pass, then converted back afterward with copied meta information. + compile_kwargs: dict of the args for `torch.compile()` API, for more details: + https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile. """ @@ -238,6 +244,8 @@ def __init__( decollate: bool = True, to_kwargs: dict | None = None, amp_kwargs: dict | None = None, + compile: bool = False, + compile_kwargs: dict | None = None, ) -> None: super().__init__( device=device, @@ -259,8 +267,16 @@ def __init__( to_kwargs=to_kwargs, amp_kwargs=amp_kwargs, ) - + if compile: + if pytorch_after(2, 1): + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] + else: + warnings.warn( + "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done" + ) self.network = network + self.compile = compile self.inferer = SimpleInferer() if inferer is None else inferer def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict: @@ -288,6 +304,24 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten kwargs: dict = {} else: inputs, targets, args, kwargs = batch + # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026 + if self.compile: + inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None + if isinstance(inputs, MetaTensor): + warnings.warn( + "Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass." + ) + inputs, inputs_meta, inputs_applied_operations = ( + inputs.as_tensor(), + inputs.meta, + inputs.applied_operations, + ) + if isinstance(targets, MetaTensor): + targets, targets_meta, targets_applied_operations = ( + targets.as_tensor(), + targets.meta, + targets.applied_operations, + ) # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} @@ -298,6 +332,19 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) else: engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) + # copy back meta info + if self.compile: + if inputs_meta is not None: + engine.state.output[Keys.IMAGE] = MetaTensor( + inputs, meta=inputs_meta, applied_operations=inputs_applied_operations + ) + engine.state.output[Keys.PRED] = MetaTensor( + engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations + ) + if targets_meta is not None: + engine.state.output[Keys.LABEL] = MetaTensor( + targets, meta=targets_meta, applied_operations=targets_applied_operations + ) engine.fire_event(IterationEvents.FORWARD_COMPLETED) engine.fire_event(IterationEvents.MODEL_COMPLETED) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 61b7028e11..c1364fe015 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -11,6 +11,7 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence import torch @@ -18,13 +19,15 @@ from torch.utils.data import DataLoader from monai.config import IgniteInfo +from monai.data import MetaTensor from monai.engines.utils import IterationEvents, default_make_latent, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform -from monai.utils import GanKeys, min_version, optional_import +from monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys +from monai.utils.module import pytorch_after if TYPE_CHECKING: from ignite.engine import Engine, EventEnum @@ -34,7 +37,7 @@ Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric") EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum") -__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer"] +__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer", "AdversarialTrainer"] class Trainer(Workflow): @@ -125,7 +128,10 @@ class SupervisedTrainer(Trainer): `device`, `non_blocking`. amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. - + compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to + `torch.Tensor` before forward pass, then converted back afterward with copied meta information. + compile_kwargs: dict of the args for `torch.compile()` API, for more details: + https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile. """ def __init__( @@ -153,6 +159,8 @@ def __init__( optim_set_to_none: bool = False, to_kwargs: dict | None = None, amp_kwargs: dict | None = None, + compile: bool = False, + compile_kwargs: dict | None = None, ) -> None: super().__init__( device=device, @@ -174,8 +182,16 @@ def __init__( to_kwargs=to_kwargs, amp_kwargs=amp_kwargs, ) - + if compile: + if pytorch_after(2, 1): + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] + else: + warnings.warn( + "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done" + ) self.network = network + self.compile = compile self.optimizer = optimizer self.loss_function = loss_function self.inferer = SimpleInferer() if inferer is None else inferer @@ -207,6 +223,25 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso kwargs: dict = {} else: inputs, targets, args, kwargs = batch + # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026 + if self.compile: + inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None + if isinstance(inputs, MetaTensor): + warnings.warn( + "Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass." + ) + inputs, inputs_meta, inputs_applied_operations = ( + inputs.as_tensor(), + inputs.meta, + inputs.applied_operations, + ) + if isinstance(targets, MetaTensor): + targets, targets_meta, targets_applied_operations = ( + targets.as_tensor(), + targets.meta, + targets.applied_operations, + ) + # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} @@ -231,6 +266,19 @@ def _compute_pred_loss(): engine.state.output[Keys.LOSS].backward() engine.fire_event(IterationEvents.BACKWARD_COMPLETED) engine.optimizer.step() + # copy back meta info + if self.compile: + if inputs_meta is not None: + engine.state.output[Keys.IMAGE] = MetaTensor( + inputs, meta=inputs_meta, applied_operations=inputs_applied_operations + ) + engine.state.output[Keys.PRED] = MetaTensor( + engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations + ) + if targets_meta is not None: + engine.state.output[Keys.LABEL] = MetaTensor( + targets, meta=targets_meta, applied_operations=targets_applied_operations + ) engine.fire_event(IterationEvents.MODEL_COMPLETED) return engine.state.output @@ -423,3 +471,282 @@ def _iteration( GanKeys.GLOSS: g_loss.item(), GanKeys.DLOSS: d_total_loss.item(), } + + +class AdversarialTrainer(Trainer): + """ + Standard supervised training workflow for adversarial loss enabled neural networks. + + Args: + device: an object representing the device on which to run. + max_epochs: the total epoch number for engine to run. + train_data_loader: Core ignite engines uses `DataLoader` for training loop batchdata. + g_network: ''generator'' (G) network architecture. + g_optimizer: G optimizer function. + g_loss_function: G loss function for adversarial training. + recon_loss_function: G loss function for reconstructions. + d_network: discriminator (D) network architecture. + d_optimizer: D optimizer function. + d_loss_function: D loss function for adversarial training.. + epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`. + non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to + the host. For other cases, this argument has no effect. + prepare_batch: function to parse image and label for current iteration. + iteration_update: the callable function for every iteration, expect to accept `engine` and `batchdata` as input + parameters. if not provided, use `self._iteration()` instead. + g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``. + d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``. + postprocessing: execute additional transformation for the model output data. Typically, several Tensor based + transforms composed by `Compose`. Defaults to None + key_train_metric: compute metric when every iteration completed, and save average value to engine.state.metrics + when epoch completed. key_train_metric is the main metric to compare and save the checkpoint into files. + additional_metrics: more Ignite metrics that also attach to Ignite Engine. + metric_cmp_fn: function to compare current key metric with previous best key metric value, it must accept 2 args + (current_metric, previous_best) and return a bool result: if `True`, will update 'best_metric` and + `best_metric_epoch` with current metric and epoch, default to `greater than`. + train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: + CheckpointHandler, StatsHandler, etc. + amp: whether to enable auto-mixed-precision training, default is False. + event_names: additional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html + #ignite.engine.engine.Engine.register_events. + decollate: whether to decollate the batch-first data to a list of data after model computation, recommend + `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. + optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. + more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. + amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + """ + + def __init__( + self, + device: torch.device | str, + max_epochs: int, + train_data_loader: Iterable | DataLoader, + g_network: torch.nn.Module, + g_optimizer: Optimizer, + g_loss_function: Callable, + recon_loss_function: Callable, + d_network: torch.nn.Module, + d_optimizer: Optimizer, + d_loss_function: Callable, + epoch_length: int | None = None, + non_blocking: bool = False, + prepare_batch: Callable = default_prepare_batch, + iteration_update: Callable | None = None, + g_inferer: Inferer | None = None, + d_inferer: Inferer | None = None, + postprocessing: Transform | None = None, + key_train_metric: dict[str, Metric] | None = None, + additional_metrics: dict[str, Metric] | None = None, + metric_cmp_fn: Callable = default_metric_cmp_fn, + train_handlers: Sequence | None = None, + amp: bool = False, + event_names: list[str | EventEnum | type[EventEnum]] | None = None, + event_to_attr: dict | None = None, + decollate: bool = True, + optim_set_to_none: bool = False, + to_kwargs: dict | None = None, + amp_kwargs: dict | None = None, + ): + super().__init__( + device=device, + max_epochs=max_epochs, + data_loader=train_data_loader, + epoch_length=epoch_length, + non_blocking=non_blocking, + prepare_batch=prepare_batch, + iteration_update=iteration_update, + postprocessing=postprocessing, + key_metric=key_train_metric, + additional_metrics=additional_metrics, + metric_cmp_fn=metric_cmp_fn, + handlers=train_handlers, + amp=amp, + event_names=event_names, + event_to_attr=event_to_attr, + decollate=decollate, + to_kwargs=to_kwargs, + amp_kwargs=amp_kwargs, + ) + + self.register_events(*AdversarialIterationEvents) + + self.state.g_network = g_network + self.state.g_optimizer = g_optimizer + self.state.g_loss_function = g_loss_function + self.state.recon_loss_function = recon_loss_function + + self.state.d_network = d_network + self.state.d_optimizer = d_optimizer + self.state.d_loss_function = d_loss_function + + self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer + self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer + + self.state.g_scaler = torch.cuda.amp.GradScaler() if self.amp else None + self.state.d_scaler = torch.cuda.amp.GradScaler() if self.amp else None + + self.optim_set_to_none = optim_set_to_none + self._complete_state_dict_user_keys() + + def _complete_state_dict_user_keys(self) -> None: + """ + This method appends to the _state_dict_user_keys AdversarialTrainer's elements that are required for + checkpoint saving. + + Follows the example found at: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine.state_dict + """ + self._state_dict_user_keys.extend( + ["g_network", "g_optimizer", "d_network", "d_optimizer", "g_scaler", "d_scaler"] + ) + + g_loss_state_dict = getattr(self.state.g_loss_function, "state_dict", None) + if callable(g_loss_state_dict): + self._state_dict_user_keys.append("g_loss_function") + + d_loss_state_dict = getattr(self.state.d_loss_function, "state_dict", None) + if callable(d_loss_state_dict): + self._state_dict_user_keys.append("d_loss_function") + + recon_loss_state_dict = getattr(self.state.recon_loss_function, "state_dict", None) + if callable(recon_loss_state_dict): + self._state_dict_user_keys.append("recon_loss_function") + + def _iteration( + self, engine: AdversarialTrainer, batchdata: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor | int | float | bool]: + """ + Callback function for the Adversarial Training processing logic of 1 iteration in Ignite Engine. + Return below items in a dictionary: + - IMAGE: image Tensor data for model input, already moved to device. + - LABEL: label Tensor data corresponding to the image, already moved to device. In case of Unsupervised + Learning this is equal to IMAGE. + - PRED: prediction result of model. + - LOSS: loss value computed by loss functions of the generator (reconstruction and adversarial summed up). + - AdversarialKeys.REALS: real images from the batch. Are the same as IMAGE. + - AdversarialKeys.FAKES: fake images generated by the generator. Are the same as PRED. + - AdversarialKeys.REAL_LOGITS: logits of the discriminator for the real images. + - AdversarialKeys.FAKE_LOGITS: logits of the discriminator for the fake images. + - AdversarialKeys.RECONSTRUCTION_LOSS: loss value computed by the reconstruction loss function. + - AdversarialKeys.GENERATOR_LOSS: loss value computed by the generator loss function. It is the + discriminator loss for the fake images. That is backpropagated through the generator only. + - AdversarialKeys.DISCRIMINATOR_LOSS: loss value computed by the discriminator loss function. It is the + discriminator loss for the real images and the fake images. That is backpropagated through the + discriminator only. + + Args: + engine: `AdversarialTrainer` to execute operation for an iteration. + batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data. + + Raises: + ValueError: must provide batch data for current iteration. + + """ + + if batchdata is None: + raise ValueError("Must provide batch data for current iteration.") + batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) + + if len(batch) == 2: + inputs, targets = batch + args: tuple = () + kwargs: dict = {} + else: + inputs, targets, args, kwargs = batch + + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets, AdversarialKeys.REALS: inputs} + + def _compute_generator_loss() -> None: + engine.state.output[AdversarialKeys.FAKES] = engine.g_inferer( + inputs, engine.state.g_network, *args, **kwargs + ) + engine.state.output[Keys.PRED] = engine.state.output[AdversarialKeys.FAKES] + engine.fire_event(AdversarialIterationEvents.GENERATOR_FORWARD_COMPLETED) + + engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer( + engine.state.output[AdversarialKeys.FAKES].float().contiguous(), engine.state.d_network, *args, **kwargs + ) + engine.fire_event(AdversarialIterationEvents.GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED) + + engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] = engine.state.recon_loss_function( + engine.state.output[AdversarialKeys.FAKES], targets + ).mean() + engine.fire_event(AdversarialIterationEvents.RECONSTRUCTION_LOSS_COMPLETED) + + engine.state.output[AdversarialKeys.GENERATOR_LOSS] = engine.state.g_loss_function( + engine.state.output[AdversarialKeys.FAKE_LOGITS] + ).mean() + engine.fire_event(AdversarialIterationEvents.GENERATOR_LOSS_COMPLETED) + + # Train Generator + engine.state.g_network.train() + engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none) + + if engine.amp and engine.state.g_scaler is not None: + with torch.cuda.amp.autocast(**engine.amp_kwargs): + _compute_generator_loss() + + engine.state.output[Keys.LOSS] = ( + engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] + + engine.state.output[AdversarialKeys.GENERATOR_LOSS] + ) + engine.state.g_scaler.scale(engine.state.output[Keys.LOSS]).backward() + engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED) + engine.state.g_scaler.step(engine.state.g_optimizer) + engine.state.g_scaler.update() + else: + _compute_generator_loss() + ( + engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] + + engine.state.output[AdversarialKeys.GENERATOR_LOSS] + ).backward() + engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED) + engine.state.g_optimizer.step() + engine.fire_event(AdversarialIterationEvents.GENERATOR_MODEL_COMPLETED) + + def _compute_discriminator_loss() -> None: + engine.state.output[AdversarialKeys.REAL_LOGITS] = engine.d_inferer( + engine.state.output[AdversarialKeys.REALS].contiguous().detach(), + engine.state.d_network, + *args, + **kwargs, + ) + engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_REALS_FORWARD_COMPLETED) + + engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer( + engine.state.output[AdversarialKeys.FAKES].contiguous().detach(), + engine.state.d_network, + *args, + **kwargs, + ) + engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_FAKES_FORWARD_COMPLETED) + + engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS] = engine.state.d_loss_function( + engine.state.output[AdversarialKeys.REAL_LOGITS], engine.state.output[AdversarialKeys.FAKE_LOGITS] + ).mean() + engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_LOSS_COMPLETED) + + # Train Discriminator + engine.state.d_network.train() + engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none) + + if engine.amp and engine.state.d_scaler is not None: + with torch.cuda.amp.autocast(**engine.amp_kwargs): + _compute_discriminator_loss() + + engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward() + engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_BACKWARD_COMPLETED) + engine.state.d_scaler.step(engine.state.d_optimizer) + engine.state.d_scaler.update() + else: + _compute_discriminator_loss() + engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS].backward() + engine.state.d_optimizer.step() + + return engine.state.output diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 02c718cd14..5339d6965a 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -13,9 +13,10 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Mapping, cast import torch +import torch.nn as nn from monai.config import IgniteInfo from monai.transforms import apply_transform @@ -36,6 +37,8 @@ "PrepareBatch", "PrepareBatchDefault", "PrepareBatchExtraInput", + "DiffusionPrepareBatch", + "VPredictionPrepareBatch", "default_make_latent", "engine_apply_transform", "default_metric_cmp_fn", @@ -238,6 +241,78 @@ def _get_data(key: str) -> torch.Tensor: return cast(torch.Tensor, image), cast(torch.Tensor, label), tuple(args_), kwargs_ +class DiffusionPrepareBatch(PrepareBatch): + """ + This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. + + Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and + return the image and noise field as the image/target pair plus the noise field the kwargs under the key "noise". + This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided. + + If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition + field to be passed to the inferer. This will appear in the keyword arguments under the key "condition". + + """ + + def __init__(self, num_train_timesteps: int, condition_name: str | None = None) -> None: + self.condition_name = condition_name + self.num_train_timesteps = num_train_timesteps + + def get_noise(self, images: torch.Tensor) -> torch.Tensor: + """Returns the noise tensor for input tensor `images`, override this for different noise distributions.""" + return torch.randn_like(images) + + def get_timesteps(self, images: torch.Tensor) -> torch.Tensor: + """Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`.""" + return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long() + + def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + """Return the target for the loss function, this is the `noise` value by default.""" + return noise + + def __call__( + self, + batchdata: dict[str, torch.Tensor], + device: str | torch.device | None = None, + non_blocking: bool = False, + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor, tuple, dict]: + images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs) + noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs) + timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs) + + target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs) + infer_kwargs = {"noise": noise, "timesteps": timesteps} + + if self.condition_name is not None and isinstance(batchdata, Mapping): + infer_kwargs["condition"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs) + + # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value + return images, target, (), infer_kwargs + + +class VPredictionPrepareBatch(DiffusionPrepareBatch): + """ + This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. + + Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and + from this compute the velocity using the provided scheduler. This value is used as the target in place of the + noise field itself although the noise is field is in the kwargs under the key "noise". This assumes the inferer + being used in conjunction with this class expects a "noise" parameter to be provided. + + If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition + field to be passed to the inferer. This will appear in the keyword arguments under the key "condition". + + """ + + def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: str | None = None) -> None: + super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name) + self.scheduler = scheduler + + def get_target(self, images, noise, timesteps): + return self.scheduler.get_velocity(images, noise, timesteps) + + def default_make_latent( num_latents: int, latent_size: int, diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py index a2bd345dc6..df209c1c8b 100644 --- a/monai/handlers/mlflow_handler.py +++ b/monai/handlers/mlflow_handler.py @@ -401,7 +401,7 @@ def _default_iteration_log(self, engine: Engine) -> None: cur_optimizer = engine.optimizer for param_name in self.optimizer_param_names: params = { - f"{param_name} group_{i}": float(param_group[param_name]) + f"{param_name}_group_{i}": float(param_group[param_name]) for i, param_group in enumerate(cur_optimizer.param_groups) } self._log_metrics(params, step=engine.state.iteration) diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index d734a9d44d..92898c81ca 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -14,7 +14,7 @@ from .adversarial_loss import PatchAdversarialLoss from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss from .contrastive import ContrastiveLoss -from .deform import BendingEnergyLoss +from .deform import BendingEnergyLoss, DiffusionLoss from .dice import ( Dice, DiceCELoss, diff --git a/monai/losses/deform.py b/monai/losses/deform.py index dd03a8eb3d..37e4468d4b 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -46,7 +46,10 @@ def spatial_gradient(x: torch.Tensor, dim: int) -> torch.Tensor: class BendingEnergyLoss(_Loss): """ - Calculate the bending energy based on second-order differentiation of pred using central finite difference. + Calculate the bending energy based on second-order differentiation of ``pred`` using central finite difference. + + For more information, + see https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb. Adapted from: DeepReg (https://github.com/DeepRegNet/DeepReg) @@ -75,6 +78,9 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + ValueError: When ``pred`` is not 3-d, 4-d or 5-d. + ValueError: When any spatial dimension of ``pred`` has size less than or equal to 4. + ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions. """ if pred.ndim not in [3, 4, 5]: @@ -84,7 +90,8 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: raise ValueError(f"All spatial dimensions must be > 4, got spatial dimensions {pred.shape[2:]}") if pred.shape[1] != pred.ndim - 2: raise ValueError( - f"Number of vector components, {pred.shape[1]}, does not match number of spatial dimensions, {pred.ndim-2}" + f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, " + f"does not match number of spatial dimensions, {pred.ndim - 2}" ) # first order gradient @@ -116,3 +123,88 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return energy + + +class DiffusionLoss(_Loss): + """ + Calculate the diffusion based on first-order differentiation of ``pred`` using central finite difference. + For the original paper, please refer to + VoxelMorph: A Learning Framework for Deformable Medical Image Registration, + Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca + IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231. + + For more information, + see https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb. + + Adapted from: + VoxelMorph (https://github.com/voxelmorph/voxelmorph) + """ + + def __init__(self, normalize: bool = False, reduction: LossReduction | str = LossReduction.MEAN) -> None: + """ + Args: + normalize: + Whether to divide out spatial sizes in order to make the computation roughly + invariant to image scale (i.e. vector field sampling resolution). Defaults to False. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + """ + super().__init__(reduction=LossReduction(reduction).value) + self.normalize = normalize + + def forward(self, pred: torch.Tensor) -> torch.Tensor: + """ + Args: + pred: + Predicted dense displacement field (DDF) with shape BCH[WD], + where C is the number of spatial dimensions. + Note that diffusion loss can only be calculated + when the sizes of the DDF along all spatial dimensions are greater than 2. + + Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + ValueError: When ``pred`` is not 3-d, 4-d or 5-d. + ValueError: When any spatial dimension of ``pred`` has size less than or equal to 2. + ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions. + + """ + if pred.ndim not in [3, 4, 5]: + raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") + for i in range(pred.ndim - 2): + if pred.shape[-i - 1] <= 2: + raise ValueError(f"All spatial dimensions must be > 2, got spatial dimensions {pred.shape[2:]}") + if pred.shape[1] != pred.ndim - 2: + raise ValueError( + f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, " + f"does not match number of spatial dimensions, {pred.ndim - 2}" + ) + + # first order gradient + first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)] + + # spatial dimensions in a shape suited for broadcasting below + if self.normalize: + spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,)) + + diffusion = torch.tensor(0) + for dim_1, g in enumerate(first_order_gradient): + dim_1 += 2 + if self.normalize: + # We divide the partial derivative for each vector component at each voxel by the spatial size + # corresponding to that component relative to the spatial size of the vector component with respect + # to which the partial derivative is taken. + g *= pred.shape[dim_1] / spatial_dims + diffusion = diffusion + g**2 + + if self.reduction == LossReduction.MEAN.value: + diffusion = torch.mean(diffusion) # the batch and channel average + elif self.reduction == LossReduction.SUM.value: + diffusion = torch.sum(diffusion) # sum over the batch and channel dims + elif self.reduction != LossReduction.NONE.value: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + + return diffusion diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 39219e059a..dd132770ec 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -277,9 +277,7 @@ def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> tuple[torc if order == 0: weight = weight + (sample_bin_matrix < 0.5) + (sample_bin_matrix == 0.5) * 0.5 elif order == 3: - weight = ( - weight + (4 - 6 * sample_bin_matrix**2 + 3 * sample_bin_matrix**3) * (sample_bin_matrix < 1) / 6 - ) + weight = weight + (4 - 6 * sample_bin_matrix**2 + 3 * sample_bin_matrix**3) * (sample_bin_matrix < 1) / 6 weight = weight + (2 - sample_bin_matrix) ** 3 * (sample_bin_matrix >= 1) * (sample_bin_matrix < 2) / 6 else: raise ValueError(f"Do not support b-spline {order}-order parzen windowing") diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index d9bbf17db3..d727eb0567 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -190,7 +190,7 @@ def compute_hausdorff_distance( y[b, c], distance_metric=distance_metric, spacing=spacing_list[b], - symetric=not directed, + symmetric=not directed, class_index=c, ) percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances] diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index 635eb1bc24..b20b47a1a5 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -253,7 +253,7 @@ def compute_surface_dice( distance_metric=distance_metric, spacing=spacing_list[b], use_subvoxels=use_subvoxels, - symetric=True, + symmetric=True, class_index=c, ) boundary_correct: int | torch.Tensor | float diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 7ce632c588..3cb336d6a0 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -177,7 +177,7 @@ def compute_average_surface_distance( y[b, c], distance_metric=distance_metric, spacing=spacing_list[b], - symetric=symmetric, + symmetric=symmetric, class_index=c, ) surface_distance = torch.cat(distances) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 62e6520b96..e7057256fb 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -38,10 +38,6 @@ binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion") distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt") distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") -cucim_binary_erosion, has_cucim_binary_erosion = optional_import("cucim.skimage.morphology", name="binary_erosion") -cucim_distance_transform_edt, has_cucim_distance_transform_edt = optional_import( - "cucim.core.operations.morphology", name="distance_transform_edt" -) __all__ = [ "ignore_background", @@ -179,6 +175,8 @@ def get_mask_edges( always_return_as_numpy: whether to a numpy array regardless of the input type. If False, return the same type as inputs. """ + # move in the funciton to avoid using all the GPUs + cucim_binary_erosion, has_cucim_binary_erosion = optional_import("cucim.skimage.morphology", name="binary_erosion") if seg_pred.shape != seg_gt.shape: raise ValueError(f"seg_pred and seg_gt should have same shapes, got {seg_pred.shape} and {seg_gt.shape}.") converter: Any @@ -295,7 +293,7 @@ def get_edge_surface_distance( distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float] | None = None, use_subvoxels: bool = False, - symetric: bool = False, + symmetric: bool = False, class_index: int = -1, ) -> tuple[ tuple[torch.Tensor, torch.Tensor], @@ -314,7 +312,7 @@ def get_edge_surface_distance( See :py:func:`monai.metrics.utils.get_surface_distance`. use_subvoxels: whether to use subvoxel resolution (using the spacing). This will return the areas of the edges. - symetric: whether to compute the surface distance from `y_pred` to `y` and from `y` to `y_pred`. + symmetric: whether to compute the surface distance from `y_pred` to `y` and from `y` to `y_pred`. class_index: The class-index used for context when warning about empty ground truth or prediction. Returns: @@ -338,7 +336,7 @@ def get_edge_surface_distance( " this may result in nan/inf distance." ) distances: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor] - if symetric: + if symmetric: distances = ( get_surface_distance(edges_pred, edges_gt, distance_metric, spacing), get_surface_distance(edges_gt, edges_pred, distance_metric, spacing), diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 10c4ce3d8e..6f96dfd291 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -1024,7 +1024,7 @@ def __init__( self.layers4.append(layer) if self.use_v2: layerc = UnetrBasicBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=embed_dim * 2**i_layer, out_channels=embed_dim * 2**i_layer, kernel_size=3, diff --git a/monai/networks/nets/transchex.py b/monai/networks/nets/transchex.py index ff27903cef..6bfff3c956 100644 --- a/monai/networks/nets/transchex.py +++ b/monai/networks/nets/transchex.py @@ -12,20 +12,17 @@ from __future__ import annotations import math -import os -import shutil -import tarfile -import tempfile from collections.abc import Sequence import torch from torch import nn +from monai.config.type_definitions import PathLike from monai.utils import optional_import transformers = optional_import("transformers") load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert")[0] -cached_path = optional_import("transformers.file_utils", name="cached_path")[0] +cached_file = optional_import("transformers.utils", name="cached_file")[0] BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0] BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0] @@ -63,44 +60,16 @@ def from_pretrained( state_dict=None, cache_dir=None, from_tf=False, + path_or_repo_id="bert-base-uncased", + filename="pytorch_model.bin", *inputs, **kwargs, ): - archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz" - resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) - tempdir = None - if os.path.isdir(resolved_archive_file) or from_tf: - serialization_dir = resolved_archive_file - else: - tempdir = tempfile.mkdtemp() - with tarfile.open(resolved_archive_file, "r:gz") as archive: - - def is_within_directory(directory, target): - abs_directory = os.path.abspath(directory) - abs_target = os.path.abspath(target) - - prefix = os.path.commonprefix([abs_directory, abs_target]) - - return prefix == abs_directory - - def safe_extract(tar, path=".", members=None, *, numeric_owner=False): - for member in tar.getmembers(): - member_path = os.path.join(path, member.name) - if not is_within_directory(path, member_path): - raise Exception("Attempted Path Traversal in Tar File") - - tar.extractall(path, members, numeric_owner=numeric_owner) - - safe_extract(archive, tempdir) - serialization_dir = tempdir + weights_path = cached_file(path_or_repo_id, filename, cache_dir=cache_dir) model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs) if state_dict is None and not from_tf: - weights_path = os.path.join(serialization_dir, "pytorch_model.bin") state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None) - if tempdir: - shutil.rmtree(tempdir) if from_tf: - weights_path = os.path.join(serialization_dir, "model.ckpt") return load_tf_weights_in_bert(model, weights_path) old_keys = [] new_keys = [] @@ -304,6 +273,8 @@ def __init__( chunk_size_feed_forward: int = 0, is_decoder: bool = False, add_cross_attention: bool = False, + path_or_repo_id: str | PathLike = "bert-base-uncased", + filename: str = "pytorch_model.bin", ) -> None: """ Args: @@ -315,6 +286,10 @@ def __init__( num_vision_layers: number of vision transformer layers. num_mixed_layers: number of mixed transformer layers. drop_out: fraction of the input units to drop. + path_or_repo_id: This can be either: + - a string, the *model id* of a model repo on huggingface.co. + - a path to a *directory* potentially containing the file. + filename: The name of the file to locate in `path_or_repo`. The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`. @@ -369,6 +344,8 @@ def __init__( num_vision_layers=num_vision_layers, num_mixed_layers=num_mixed_layers, bert_config=bert_config, + path_or_repo_id=path_or_repo_id, + filename=filename, ) self.patch_size = patch_size diff --git a/monai/networks/nets/vqvae.py b/monai/networks/nets/vqvae.py index d4771e203a..f198bfbb2b 100644 --- a/monai/networks/nets/vqvae.py +++ b/monai/networks/nets/vqvae.py @@ -312,10 +312,16 @@ def __init__( channels: Sequence[int] = (96, 96, 192), num_res_layers: int = 3, num_res_channels: Sequence[int] | int = (96, 96, 192), - downsample_parameters: Sequence[Tuple[int, int, int, int]] - | Tuple[int, int, int, int] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), - upsample_parameters: Sequence[Tuple[int, int, int, int, int]] - | Tuple[int, int, int, int, int] = ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + downsample_parameters: Sequence[Tuple[int, int, int, int]] | Tuple[int, int, int, int] = ( + (2, 4, 1, 1), + (2, 4, 1, 1), + (2, 4, 1, 1), + ), + upsample_parameters: Sequence[Tuple[int, int, int, int, int]] | Tuple[int, int, int, int, int] = ( + (2, 4, 1, 1, 0), + (2, 4, 1, 1, 0), + (2, 4, 1, 1, 0), + ), num_embeddings: int = 32, embedding_dim: int = 64, embedding_init: str = "normal", diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 56d214c51d..be9441dc4a 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -221,9 +221,8 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension. """ - LazyTransform.__init__(self, lazy) padder = SpatialPad(spatial_size, method, lazy=lazy, **kwargs) - Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy) class BorderPadd(Padd): @@ -274,9 +273,8 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension. """ - LazyTransform.__init__(self, lazy) padder = BorderPad(spatial_border=spatial_border, lazy=lazy, **kwargs) - Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy) class DivisiblePadd(Padd): @@ -324,9 +322,8 @@ def __init__( See also :py:class:`monai.transforms.SpatialPad` """ - LazyTransform.__init__(self, lazy) padder = DivisiblePad(k=k, method=method, lazy=lazy, **kwargs) - Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy) class Cropd(MapTransform, InvertibleTransform, LazyTransform): diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 41fabb35aa..f94f11eca9 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -185,7 +185,17 @@ def track_transform_meta( # not lazy evaluation, directly update the metatensor affine (don't push to the stack) orig_affine = data_t.peek_pending_affine() orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0] - affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64) + try: + affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64) + except RuntimeError as e: + if orig_affine.ndim > 2: + if data_t.is_batch: + msg = "Transform applied to batched tensor, should be applied to instances only" + else: + msg = "Mismatch affine matrix, ensured that the batch dimension is not included in the calculation." + raise RuntimeError(msg) from e + else: + raise out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"), dtype=torch.float64) if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)): diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index cd7e4ef090..7222a26fc3 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -414,6 +414,9 @@ def __init__( self.fname_formatter = output_name_formatter self.output_ext = output_ext.lower() or output_format.lower() + self.output_ext = ( + f".{self.output_ext}" if self.output_ext and not self.output_ext.startswith(".") else self.output_ext + ) if isinstance(writer, str): writer_, has_built_in = optional_import("monai.data", name=f"{writer}") # search built-in if not has_built_in: @@ -458,15 +461,23 @@ def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, writ self.write_kwargs.update(write_kwargs) return self - def __call__(self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None): + def __call__( + self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None, filename: str | PathLike | None = None + ): """ Args: img: target data content that save into file. The image should be channel-first, shape: `[C,H,W,[D]]`. meta_data: key-value pairs of metadata corresponding to the data. + filename: str or file-like object which to save img. + If specified, will ignore `self.output_name_formatter` and `self.folder_layout`. """ meta_data = img.meta if isinstance(img, MetaTensor) else meta_data - kw = self.fname_formatter(meta_data, self) - filename = self.folder_layout.filename(**kw) + if filename is not None: + filename = f"{filename}{self.output_ext}" + else: + kw = self.fname_formatter(meta_data, self) + filename = self.folder_layout.filename(**kw) + if meta_data: meta_spatial_shape = ensure_tuple(meta_data.get("spatial_shape", ())) if len(meta_spatial_shape) >= len(img.shape): diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 2322f2123f..5dfbcb0e91 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1562,17 +1562,22 @@ def __init__(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | self.filter_size = filter_size self.additional_args_for_filter = kwargs - def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> NdarrayOrTensor: + def __call__( + self, img: NdarrayOrTensor, meta_dict: dict | None = None, applied_operations: list | None = None + ) -> NdarrayOrTensor: """ Args: img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]] meta_dict: An optional dictionary with metadata + applied_operations: An optional list of operations that have been applied to the data Returns: A MetaTensor with the same shape as `img` and identical metadata """ if isinstance(img, MetaTensor): meta_dict = img.meta + applied_operations = img.applied_operations + img_, prev_type, device = convert_data_type(img, torch.Tensor) ndim = img_.ndim - 1 # assumes channel first format @@ -1582,8 +1587,8 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> Ndarr self.filter = ApplyFilter(self.filter) img_ = self._apply_filter(img_) - if meta_dict: - img_ = MetaTensor(img_, meta=meta_dict) + if meta_dict is not None or applied_operations is not None: + img_ = MetaTensor(img_, meta=meta_dict, applied_operations=applied_operations) else: img_, *_ = convert_data_type(img_, prev_type, device) return img_ diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index ec10bd8537..1cd9ff6323 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1765,9 +1765,9 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N LabelToMaskD = LabelToMaskDict = LabelToMaskd FgBgToIndicesD = FgBgToIndicesDict = FgBgToIndicesd ClassesToIndicesD = ClassesToIndicesDict = ClassesToIndicesd -ConvertToMultiChannelBasedOnBratsClassesD = ( - ConvertToMultiChannelBasedOnBratsClassesDict -) = ConvertToMultiChannelBasedOnBratsClassesd +ConvertToMultiChannelBasedOnBratsClassesD = ConvertToMultiChannelBasedOnBratsClassesDict = ( + ConvertToMultiChannelBasedOnBratsClassesd +) AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld TorchVisionD = TorchVisionDict = TorchVisiond RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 20f09628ac..2418b43591 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -50,18 +50,15 @@ def get_dist_device(): @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: ... @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]: ... @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: ... def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True) -> torch.Tensor | list[torch.Tensor]: diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 4f2501a7ee..81f582daef 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -103,13 +103,11 @@ def star_zip_with(op, *vals): @overload -def first(iterable: Iterable[T], default: T) -> T: - ... +def first(iterable: Iterable[T], default: T) -> T: ... @overload -def first(iterable: Iterable[T]) -> T | None: - ... +def first(iterable: Iterable[T]) -> T | None: ... def first(iterable: Iterable[T], default: T | None = None) -> T | None: diff --git a/requirements-dev.txt b/requirements-dev.txt index 6332d5b0a5..f8bc9d5a3e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,7 @@ # Full requirements for developments -r requirements-min.txt pytorch-ignite==0.4.11 -gdown>=4.4.0 +gdown>=4.4.0, <=4.6.3 scipy>=1.7.1 itk>=5.2 nibabel @@ -27,13 +27,13 @@ ninja torchvision psutil cucim>=23.2.0; platform_system == "Linux" -openslide-python==1.1.2 +openslide-python imagecodecs; platform_system == "Linux" or platform_system == "Darwin" tifffile; platform_system == "Linux" or platform_system == "Darwin" pandas requests einops -transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157 +transformers>=4.36.0 mlflow>=1.28.0 clearml>=1.10.0rc0 matplotlib!=3.5.0 diff --git a/setup.cfg b/setup.cfg index 0069214de3..4180ced917 100644 --- a/setup.cfg +++ b/setup.cfg @@ -174,6 +174,7 @@ max_line_length = 120 # B907 https://github.com/Project-MONAI/MONAI/issues/5868 # B908 https://github.com/Project-MONAI/MONAI/issues/6503 # B036 https://github.com/Project-MONAI/MONAI/issues/7396 +# E704 https://github.com/Project-MONAI/MONAI/issues/7421 ignore = E203 E501 @@ -188,6 +189,7 @@ ignore = B907 B908 B036 + E704 per_file_ignores = __init__.py: F401, __main__.py: F401 exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py diff --git a/tests/min_tests.py b/tests/min_tests.py index 8128bb7b84..3a143df84b 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -154,6 +154,7 @@ def run_testsuit(): "test_plot_2d_or_3d_image", "test_png_rw", "test_prepare_batch_default", + "test_prepare_batch_diffusion", "test_prepare_batch_extra_input", "test_prepare_batch_hovernet", "test_rand_grid_patch", diff --git a/tests/padders.py b/tests/padders.py index 02d7b40af6..ae1153bdfd 100644 --- a/tests/padders.py +++ b/tests/padders.py @@ -136,6 +136,9 @@ def pad_test_pending_ops(self, input_param, input_shape): # TODO: mode="bilinear" may report error overrides = {"mode": "nearest", "padding_mode": mode[1], "align_corners": False} result = apply_pending(pending_result, overrides=overrides)[0] + # lazy in constructor + pad_fn_lazy = self.Padder(mode=mode[0], lazy=True, **input_param) + self.assertTrue(pad_fn_lazy.lazy) # compare assert_allclose(result, expected, rtol=1e-5) if isinstance(result, MetaTensor) and not isinstance(pad_fn, MapTransform): diff --git a/tests/test_diffusion_loss.py b/tests/test_diffusion_loss.py new file mode 100644 index 0000000000..05dfab95fb --- /dev/null +++ b/tests/test_diffusion_loss.py @@ -0,0 +1,116 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses.deform import DiffusionLoss + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASES = [ + # all first partials are zero, so the diffusion loss is also zero + [{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0], + # all first partials are one, so the diffusion loss is also one + [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 1.0], + # before expansion, the first partials are 2, 4, 6, so the diffusion loss is (2^2 + 4^2 + 6^2) / 3 = 18.67 + [ + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 56.0 / 3.0, + ], + # same as the previous case + [ + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 56.0 / 3.0, + ], + # same as the previous case + [{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0], + # we have shown in the demo notebook that + # diffusion loss is scale-invariant when the all axes have the same resolution + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 56.0 / 3.0, + ], + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 56.0 / 3.0, + ], + [{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0], + # for the following case, consider the following 2D matrix: + # tensor([[[[0, 1, 2], + # [1, 2, 3], + # [2, 3, 4], + # [3, 4, 5], + # [4, 5, 6]], + # [[0, 1, 2], + # [1, 2, 3], + # [2, 3, 4], + # [3, 4, 5], + # [4, 5, 6]]]]) + # the first partials wrt x are all ones, and so are the first partials wrt y + # the diffusion loss, when normalization is not applied, is 1^2 + 1^2 = 2 + [{"normalize": False}, {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, 2.0], + # consider the same matrix, this time with normalization applied, using the same notation as in the demo notebook, + # the coefficients to be divided out are (1, 5/3) for partials wrt x and (3/5, 1) for partials wrt y + # the diffusion loss is then (1/1)^2 + (1/(5/3))^2 + (1/(3/5))^2 + (1/1)^2 = (1 + 9/25 + 25/9 + 1) / 2 = 2.5689 + [ + {"normalize": True}, + {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, + (1.0 + 9.0 / 25.0 + 25.0 / 9.0 + 1.0) / 2.0, + ], +] + + +class TestDiffusionLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = DiffusionLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) + + def test_ill_shape(self): + loss = DiffusionLoss() + # not in 3-d, 4-d, 5-d + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): + loss.forward(torch.ones((1, 3), device=device)) + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): + loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device)) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 2, 5, 5), device=device)) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 5, 2, 5))) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 5, 5, 2))) + + # number of vector components unequal to number of spatial dims + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + + def test_ill_opts(self): + pred = torch.rand(1, 3, 5, 5, 5).to(device=device) + with self.assertRaisesRegex(ValueError, ""): + DiffusionLoss(reduction="unknown")(pred) + with self.assertRaisesRegex(ValueError, ""): + DiffusionLoss(reduction=None)(pred) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py index 4c49aecd8b..68fa0b1192 100644 --- a/tests/test_hilbert_transform.py +++ b/tests/test_hilbert_transform.py @@ -180,15 +180,17 @@ def test_value(self, arguments, image, expected_data, atol): @SkipIfNoModule("torch.fft") class TestHilbertTransformGPU(unittest.TestCase): @parameterized.expand( - [] - if not torch.cuda.is_available() - else [ - TEST_CASE_1D_SINE_GPU, - TEST_CASE_2D_SINE_GPU, - TEST_CASE_3D_SINE_GPU, - TEST_CASE_1D_2CH_SINE_GPU, - TEST_CASE_2D_2CH_SINE_GPU, - ], + ( + [] + if not torch.cuda.is_available() + else [ + TEST_CASE_1D_SINE_GPU, + TEST_CASE_2D_SINE_GPU, + TEST_CASE_3D_SINE_GPU, + TEST_CASE_1D_2CH_SINE_GPU, + TEST_CASE_2D_2CH_SINE_GPU, + ] + ), skip_on_empty=True, ) def test_value(self, arguments, image, expected_data, atol): diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index 841a5d5cd5..985ea95e79 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -17,6 +17,7 @@ import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms import ImageFilter, ImageFilterd, RandImageFilter, RandImageFilterd @@ -115,6 +116,21 @@ def test_call_3d(self, filter_name): out_tensor = filter(SAMPLE_IMAGE_3D) self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_3D.shape[1:]) + def test_pass_applied_operations(self): + "Test that applied operations are passed through" + applied_operations = ["op1", "op2"] + image = MetaTensor(SAMPLE_IMAGE_2D, applied_operations=applied_operations) + filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(image) + self.assertEqual(out_tensor.applied_operations, applied_operations) + + def test_pass_empty_metadata_dict(self): + "Test that applied operations are passed through" + image = MetaTensor(SAMPLE_IMAGE_2D, meta={}) + filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(image) + self.assertTrue(isinstance(out_tensor, MetaTensor)) + class TestImageFilterDict(unittest.TestCase): @parameterized.expand(SUPPORTED_FILTERS) diff --git a/tests/test_integration_workflows_adversarial.py b/tests/test_integration_workflows_adversarial.py new file mode 100644 index 0000000000..f323fc9917 --- /dev/null +++ b/tests/test_integration_workflows_adversarial.py @@ -0,0 +1,173 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import shutil +import tempfile +import unittest +from glob import glob + +import numpy as np +import torch + +import monai +from monai.data import create_test_image_2d +from monai.engines import AdversarialTrainer +from monai.handlers import CheckpointSaver, StatsHandler, TensorBoardStatsHandler +from monai.networks.nets import AutoEncoder, Discriminator +from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, RandFlipd, ScaleIntensityd +from monai.utils import AdversarialKeys as Keys +from monai.utils import CommonKeys, optional_import, set_determinism +from tests.utils import DistTestCase, TimedCall, skip_if_quick + +nib, has_nibabel = optional_import("nibabel") + + +def run_training_test(root_dir, device="cuda:0"): + learning_rate = 2e-4 + real_label = 1 + fake_label = 0 + + real_images = sorted(glob(os.path.join(root_dir, "img*.nii.gz"))) + train_files = [{CommonKeys.IMAGE: img, CommonKeys.LABEL: img} for img in zip(real_images)] + + # prepare real data + train_transforms = Compose( + [ + LoadImaged(keys=[CommonKeys.IMAGE, CommonKeys.LABEL]), + EnsureChannelFirstd(keys=[CommonKeys.IMAGE, CommonKeys.LABEL], channel_dim=2), + ScaleIntensityd(keys=[CommonKeys.IMAGE]), + RandFlipd(keys=[CommonKeys.IMAGE, CommonKeys.LABEL], prob=0.5), + ] + ) + train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5) + train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4) + + # Create Discriminator + discriminator_net = Discriminator( + in_shape=(1, 64, 64), channels=(8, 16, 32, 64, 1), strides=(2, 2, 2, 2, 1), num_res_units=1, kernel_size=5 + ).to(device) + discriminator_opt = torch.optim.Adam(discriminator_net.parameters(), learning_rate) + discriminator_loss_criterion = torch.nn.BCELoss() + + def discriminator_loss(real_logits, fake_logits): + real_target = real_logits.new_full((real_logits.shape[0], 1), real_label) + fake_target = fake_logits.new_full((fake_logits.shape[0], 1), fake_label) + real_loss = discriminator_loss_criterion(real_logits, real_target) + fake_loss = discriminator_loss_criterion(fake_logits.detach(), fake_target) + return torch.div(torch.add(real_loss, fake_loss), 2) + + # Create Generator + generator_network = AutoEncoder( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(8, 16, 32, 64), + strides=(2, 2, 2, 2), + num_res_units=1, + num_inter_units=1, + ) + generator_network = generator_network.to(device) + generator_optimiser = torch.optim.Adam(generator_network.parameters(), learning_rate) + generator_loss_criterion = torch.nn.MSELoss() + + def reconstruction_loss(recon_images, real_images): + return generator_loss_criterion(recon_images, real_images) + + def generator_loss(fake_logits): + fake_target = fake_logits.new_full((fake_logits.shape[0], 1), real_label) + recon_loss = discriminator_loss_criterion(fake_logits.detach(), fake_target) + return recon_loss + + key_train_metric = None + + train_handlers = [ + StatsHandler( + name="training_loss", + output_transform=lambda x: { + Keys.RECONSTRUCTION_LOSS: x[Keys.RECONSTRUCTION_LOSS], + Keys.DISCRIMINATOR_LOSS: x[Keys.DISCRIMINATOR_LOSS], + Keys.GENERATOR_LOSS: x[Keys.GENERATOR_LOSS], + }, + ), + TensorBoardStatsHandler( + log_dir=root_dir, + tag_name="training_loss", + output_transform=lambda x: { + Keys.RECONSTRUCTION_LOSS: x[Keys.RECONSTRUCTION_LOSS], + Keys.DISCRIMINATOR_LOSS: x[Keys.DISCRIMINATOR_LOSS], + Keys.GENERATOR_LOSS: x[Keys.GENERATOR_LOSS], + }, + ), + CheckpointSaver( + save_dir=root_dir, + save_dict={"g_net": generator_network, "d_net": discriminator_net}, + save_interval=2, + epoch_level=True, + ), + ] + + num_epochs = 5 + + trainer = AdversarialTrainer( + device=device, + max_epochs=num_epochs, + train_data_loader=train_loader, + g_network=generator_network, + g_optimizer=generator_optimiser, + g_loss_function=generator_loss, + recon_loss_function=reconstruction_loss, + d_network=discriminator_net, + d_optimizer=discriminator_opt, + d_loss_function=discriminator_loss, + non_blocking=True, + key_train_metric=key_train_metric, + train_handlers=train_handlers, + ) + trainer.run() + + return trainer.state + + +@skip_if_quick +@unittest.skipUnless(has_nibabel, "Requires nibabel library.") +class IntegrationWorkflowsAdversarialTrainer(DistTestCase): + def setUp(self): + set_determinism(seed=0) + + self.data_dir = tempfile.mkdtemp() + for i in range(40): + im, _ = create_test_image_2d(64, 64, num_objs=3, rad_max=14, num_seg_classes=1, channel_dim=-1) + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(self.data_dir, f"img{i:d}.nii.gz")) + + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0") + monai.config.print_config() + + def tearDown(self): + set_determinism(seed=None) + shutil.rmtree(self.data_dir) + + @TimedCall(seconds=200, daemon=False) + def test_training(self): + torch.manual_seed(0) + + finish_state = run_training_test(self.data_dir, device=self.device) + + # Assert AdversarialTrainer training finished + self.assertEqual(finish_state.iteration, 100) + self.assertEqual(finish_state.epoch, 5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_prepare_batch_diffusion.py b/tests/test_prepare_batch_diffusion.py new file mode 100644 index 0000000000..d969c06368 --- /dev/null +++ b/tests/test_prepare_batch_diffusion.py @@ -0,0 +1,104 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.engines import SupervisedEvaluator +from monai.engines.utils import DiffusionPrepareBatch +from monai.inferers import DiffusionInferer +from monai.networks.nets import DiffusionModelUNet +from monai.networks.schedulers import DDPMScheduler + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8, 8), + ], +] + + +class TestPrepareBatchDiffusion(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_output_sizes(self, input_args, image_size): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataloader = [{"image": torch.randn(image_size).to(device)}] + scheduler = DDPMScheduler(num_train_timesteps=20) + inferer = DiffusionInferer(scheduler=scheduler) + network = DiffusionModelUNet(**input_args).to(device) + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=dataloader, + epoch_length=1, + network=network, + inferer=inferer, + non_blocking=True, + prepare_batch=DiffusionPrepareBatch(num_train_timesteps=20), + decollate=False, + ) + evaluator.run() + output = evaluator.state.output + # check shapes are the same + self.assertEqual(output["pred"].shape, image_size) + self.assertEqual(output["label"].shape, output["image"].shape) + + @parameterized.expand(TEST_CASES) + def test_conditioning(self, input_args, image_size): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataloader = [{"image": torch.randn(image_size).to(device), "context": torch.randn((2, 4, 3)).to(device)}] + scheduler = DDPMScheduler(num_train_timesteps=20) + inferer = DiffusionInferer(scheduler=scheduler) + network = DiffusionModelUNet(**input_args, with_conditioning=True, cross_attention_dim=3).to(device) + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=dataloader, + epoch_length=1, + network=network, + inferer=inferer, + non_blocking=True, + prepare_batch=DiffusionPrepareBatch(num_train_timesteps=20, condition_name="context"), + decollate=False, + ) + evaluator.run() + output = evaluator.state.output + # check shapes are the same + self.assertEqual(output["pred"].shape, image_size) + self.assertEqual(output["label"].shape, output["image"].shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_save_image.py b/tests/test_save_image.py index ba94ab5087..d88db201ce 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -37,6 +37,8 @@ False, ] +TEST_CASE_5 = [torch.randint(0, 255, (3, 2, 4, 5), dtype=torch.uint8), ".dcm", False] + @unittest.skipUnless(has_itk, "itk not installed") class TestSaveImage(unittest.TestCase): @@ -58,6 +60,20 @@ def test_saved_content(self, test_data, meta_data, output_ext, resample): filepath = "testfile0" if meta_data is not None else "0" self.assertTrue(os.path.exists(os.path.join(tempdir, filepath + "_trans" + output_ext))) + @parameterized.expand([TEST_CASE_5]) + def test_saved_content_with_filename(self, test_data, output_ext, resample): + with tempfile.TemporaryDirectory() as tempdir: + trans = SaveImage( + output_dir=tempdir, + output_ext=output_ext, + resample=resample, + separate_folder=False, # test saving into the same folder + ) + filename = str(os.path.join(tempdir, "test")) + trans(test_data, filename=filename) + + self.assertTrue(os.path.exists(filename + output_ext)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_set_visible_devices.py b/tests/test_set_visible_devices.py index 53703e107a..993e8a4ac2 100644 --- a/tests/test_set_visible_devices.py +++ b/tests/test_set_visible_devices.py @@ -35,6 +35,13 @@ def test_visible_devices(self): ) self.assertEqual(num_gpus_before, num_gpus_after) + # test import monai won't affect setting CUDA_VISIBLE_DEVICES + num_gpus_after_monai = self.run_process_and_get_exit_code( + 'python -c "import os; import torch; import monai; ' + + "os.environ['CUDA_VISIBLE_DEVICES'] = '0'; exit(torch.cuda.device_count())\"" + ) + self.assertEqual(num_gpus_after_monai, 1) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_spacing.py b/tests/test_spacing.py index 1ff1518297..8b664641d7 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -74,9 +74,11 @@ torch.ones((1, 2, 1, 2)), # data torch.tensor([[2, 1, 0, 4], [-1, -3, 0, 5], [0, 0, 2.0, 5], [0, 0, 0, 1]]), {}, - torch.tensor([[[[0.75, 0.75]], [[0.75, 0.75]], [[0.75, 0.75]]]]) - if USE_COMPILED - else torch.tensor([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]), + ( + torch.tensor([[[[0.75, 0.75]], [[0.75, 0.75]], [[0.75, 0.75]]]]) + if USE_COMPILED + else torch.tensor([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]) + ), *device, ] ) From ba188e24c5b6d733b04f743a47ec84513a7dbb7a Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 23 Apr 2024 16:18:18 +0100 Subject: [PATCH 14/32] monai generative: refactor autoencoderkl (#7552) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Part of the refactoring in #7227 ### Description Refactors autoencoderkl. Changes are: - Introduce `CastToTempType` class for upsampling - `Downsample` block removed and replaced by a `Sequential` - The attention block now uses MONAI's `SABlock`, allowing a lot of code to be removed - Added a `load_old_state_dict` that allows for models trained on MONAI Generative to be loaded in to this model, especially important given some of the MONAI Generative's [model zoo](https://github.com/Project-MONAI/GenerativeModels/tree/main/model-zoo) uses this model. I have tested this works locally. I discussed with @ericspod inheriting from `AutoEncoder` but after experimentation have decided against it as it introduced changes that made it very hard to ensure we could load model's trained in MONAI Generative. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu Signed-off-by: kaibo Signed-off-by: heyufan1995 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: binliu Signed-off-by: dependabot[bot] Signed-off-by: axel.vlaminck Signed-off-by: monai-bot Signed-off-by: Ibrahim Hadzic Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> Signed-off-by: Timothy Baker Signed-off-by: Mathijs de Boer Signed-off-by: Fabian Klopfer Signed-off-by: Lucas Robinet Signed-off-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Signed-off-by: chaoliu Signed-off-by: cxlcl Signed-off-by: chaoliu Signed-off-by: Suraj Pai Signed-off-by: Juan Pablo de la Cruz Gutiérrez Signed-off-by: elitap Signed-off-by: Felix Schnabel Signed-off-by: YanxuanLiu Signed-off-by: ytl0623 Signed-off-by: Dženan Zukić Signed-off-by: Ishan Dutta Signed-off-by: John Zielke Signed-off-by: Mingxin Zheng Signed-off-by: Vladimir Chernyi <57420464+scalyvladimir@users.noreply.github.com> Signed-off-by: Yiheng Wang Signed-off-by: Szabolcs Botond Lorincz Molnar Signed-off-by: Mark Graham Signed-off-by: Lucas Robinet Signed-off-by: Mingxin Signed-off-by: Han Wang Signed-off-by: Konstantin Sukharev Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Kaibo Tang Co-authored-by: Yufan He <59374597+heyufan1995@users.noreply.github.com> Co-authored-by: binliunls <107988372+binliunls@users.noreply.github.com> Co-authored-by: Ben Murray Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: axel.vlaminck Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Co-authored-by: monai-bot <64792179+monai-bot@users.noreply.github.com> Co-authored-by: Ibrahim Hadzic Co-authored-by: Dr. Behrooz Hashemian <3968947+drbeh@users.noreply.github.com> Co-authored-by: Timothy J. Baker <62781117+tim-the-baker@users.noreply.github.com> Co-authored-by: Mathijs de Boer <8137653+MathijsdeBoer@users.noreply.github.com> Co-authored-by: Mathijs de Boer Co-authored-by: Fabian Klopfer Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Co-authored-by: Lucas Robinet Co-authored-by: cxlcl Co-authored-by: Suraj Pai Co-authored-by: Juampa <1523654+juampatronics@users.noreply.github.com> Co-authored-by: elitap Co-authored-by: Felix Schnabel Co-authored-by: YanxuanLiu <104543031+YanxuanLiu@users.noreply.github.com> Co-authored-by: ytl0623 Co-authored-by: Dženan Zukić Co-authored-by: Ishan Dutta Co-authored-by: johnzielke Co-authored-by: Vladimir Chernyi <57420464+scalyvladimir@users.noreply.github.com> Co-authored-by: Lőrincz-Molnár Szabolcs-Botond Co-authored-by: Nic Ma Co-authored-by: Lucas Robinet Co-authored-by: Han Wang Co-authored-by: Konstantin Sukharev <50718389+k-sukharev@users.noreply.github.com> --- .github/workflows/chatops.yml | 2 +- .github/workflows/conda.yml | 2 +- .github/workflows/cron-ngc-bundle.yml | 2 +- .github/workflows/cron.yml | 22 +- .github/workflows/docker.yml | 9 +- .github/workflows/integration.yml | 8 +- .github/workflows/pythonapp-gpu.yml | 10 +- .github/workflows/pythonapp-min.yml | 6 +- .github/workflows/pythonapp.yml | 10 +- .github/workflows/release.yml | 6 +- .github/workflows/setupapp.yml | 14 +- .pre-commit-config.yaml | 4 +- Dockerfile | 14 +- docs/requirements.txt | 2 +- docs/source/config_syntax.md | 7 +- docs/source/losses.rst | 10 + docs/source/networks.rst | 10 + docs/source/transforms.rst | 54 +++ docs/source/transforms_idx.rst | 10 + monai/__init__.py | 5 + monai/apps/auto3dseg/auto_runner.py | 28 +- monai/apps/auto3dseg/bundle_gen.py | 32 +- monai/apps/auto3dseg/ensemble_builder.py | 2 +- monai/apps/auto3dseg/hpo_gen.py | 4 +- monai/apps/datasets.py | 1 + monai/apps/deepedit/transforms.py | 3 + monai/apps/detection/metrics/coco.py | 1 + monai/apps/detection/utils/ATSS_matcher.py | 1 + monai/apps/nnunet/nnunetv2_runner.py | 116 +++--- monai/apps/pathology/transforms/post/array.py | 6 +- monai/apps/utils.py | 2 +- monai/auto3dseg/analyzer.py | 4 +- monai/bundle/config_item.py | 7 +- monai/bundle/scripts.py | 17 +- monai/bundle/utils.py | 6 +- monai/bundle/workflows.py | 117 ++++-- monai/data/dataset.py | 3 +- monai/data/video_dataset.py | 6 +- monai/fl/client/client_algo.py | 1 + monai/handlers/ignite_metric.py | 1 + monai/losses/__init__.py | 2 + monai/losses/barlow_twins.py | 84 +++++ monai/losses/dice.py | 42 ++- monai/losses/focal_loss.py | 5 +- monai/losses/perceptual.py | 59 ++- monai/losses/sure_loss.py | 200 +++++++++++ monai/metrics/f_beta_score.py | 1 + monai/metrics/metric.py | 3 + monai/metrics/regression.py | 2 +- monai/networks/blocks/dynunet_block.py | 1 + monai/networks/blocks/localnet_block.py | 2 + monai/networks/blocks/patchembedding.py | 4 +- monai/networks/blocks/pos_embed_utils.py | 3 +- monai/networks/blocks/upsample.py | 32 +- monai/networks/layers/__init__.py | 3 +- monai/networks/layers/conjugate_gradient.py | 112 ++++++ monai/networks/layers/gmm.py | 1 + monai/networks/layers/simplelayers.py | 2 + monai/networks/layers/spatial_transforms.py | 5 + monai/networks/nets/__init__.py | 2 + monai/networks/nets/ahnet.py | 6 + monai/networks/nets/attentionunet.py | 4 + monai/networks/nets/autoencoderkl.py | 337 ++++++++---------- monai/networks/nets/basic_unet.py | 1 + monai/networks/nets/basic_unetplusplus.py | 1 + monai/networks/nets/densenet.py | 3 + monai/networks/nets/dints.py | 4 + monai/networks/nets/efficientnet.py | 4 + monai/networks/nets/flexible_unet.py | 13 +- monai/networks/nets/highresnet.py | 1 + monai/networks/nets/hovernet.py | 6 + monai/networks/nets/milmodel.py | 1 + monai/networks/nets/regunet.py | 2 + monai/networks/nets/resnet.py | 145 +++++++- monai/networks/nets/spade_autoencoderkl.py | 43 ++- monai/networks/nets/vnet.py | 5 + monai/networks/utils.py | 28 +- monai/optimizers/lr_finder.py | 2 + monai/optimizers/utils.py | 2 + monai/transforms/__init__.py | 16 + monai/transforms/adaptors.py | 4 + monai/transforms/intensity/array.py | 171 ++++++++- monai/transforms/intensity/dictionary.py | 46 ++- monai/transforms/inverse_batch_transform.py | 1 + monai/transforms/post/array.py | 1 + monai/transforms/regularization/__init__.py | 10 + monai/transforms/regularization/array.py | 174 +++++++++ monai/transforms/regularization/dictionary.py | 97 +++++ monai/transforms/smooth_field/array.py | 2 +- monai/transforms/spatial/array.py | 2 +- monai/transforms/utility/dictionary.py | 1 + monai/transforms/utils.py | 48 ++- .../utils_pytorch_numpy_unification.py | 15 + monai/utils/enums.py | 2 +- monai/utils/misc.py | 3 +- monai/utils/module.py | 17 +- monai/utils/profiling.py | 1 + monai/visualize/class_activation_maps.py | 2 + monai/visualize/gradient_based.py | 1 + pyproject.toml | 4 +- requirements-dev.txt | 13 +- runtests.sh | 2 + setup.cfg | 12 +- tests/croppers.py | 1 + tests/hvd_evenly_divisible_all_gather.py | 1 + tests/ngc_bundle_download.py | 2 + tests/nonconfig_workflow.py | 4 +- tests/padders.py | 1 + tests/profile_subclass/min_classes.py | 1 + tests/test_acn_block.py | 1 + tests/test_activations.py | 1 + tests/test_activationsd.py | 1 + tests/test_adaptors.py | 11 + tests/test_add_coordinate_channels.py | 1 + tests/test_add_coordinate_channelsd.py | 1 + tests/test_add_extreme_points_channel.py | 1 + tests/test_add_extreme_points_channeld.py | 1 + tests/test_adjust_contrast.py | 1 + tests/test_adjust_contrastd.py | 1 + tests/test_adn.py | 2 + tests/test_adversarial_loss.py | 1 + tests/test_affine.py | 2 + tests/test_affine_grid.py | 1 + tests/test_affine_transform.py | 28 +- tests/test_affined.py | 1 + tests/test_ahnet.py | 6 + tests/test_anchor_box.py | 1 + tests/test_apply.py | 1 + tests/test_apply_filter.py | 1 + tests/test_arraydataset.py | 2 + tests/test_as_channel_last.py | 1 + tests/test_as_channel_lastd.py | 1 + tests/test_as_discrete.py | 1 + tests/test_as_discreted.py | 1 + tests/test_atss_box_matcher.py | 1 + tests/test_attentionunet.py | 1 + tests/test_auto3dseg.py | 2 +- tests/test_auto3dseg_bundlegen.py | 1 + tests/test_auto3dseg_ensemble.py | 1 + tests/test_auto3dseg_hpo.py | 4 +- tests/test_autoencoder.py | 1 + tests/test_autoencoderkl.py | 42 ++- tests/test_avg_merger.py | 1 + tests/test_barlow_twins_loss.py | 109 ++++++ tests/test_basic_unet.py | 1 + tests/test_basic_unetplusplus.py | 1 + tests/test_bending_energy.py | 1 + tests/test_bilateral_approx_cpu.py | 1 + tests/test_bilateral_approx_cuda.py | 1 + tests/test_bilateral_precise.py | 2 + tests/test_blend_images.py | 1 + tests/test_bounding_rect.py | 1 + tests/test_bounding_rectd.py | 1 + tests/test_box_coder.py | 1 + tests/test_box_transform.py | 1 + tests/test_box_utils.py | 1 + tests/test_bundle_ckpt_export.py | 1 + tests/test_bundle_download.py | 3 + tests/test_bundle_get_data.py | 1 + tests/test_bundle_init_bundle.py | 1 + tests/test_bundle_onnx_export.py | 1 + tests/test_bundle_push_to_hf_hub.py | 1 + tests/test_bundle_trt_export.py | 1 + tests/test_bundle_utils.py | 2 + tests/test_bundle_verify_metadata.py | 1 + tests/test_bundle_verify_net.py | 1 + tests/test_bundle_workflow.py | 19 + tests/test_cachedataset.py | 1 + tests/test_cachedataset_parallel.py | 1 + tests/test_cachedataset_persistent_workers.py | 1 + tests/test_cachentransdataset.py | 1 + tests/test_call_dist.py | 1 + tests/test_cast_to_type.py | 1 + tests/test_cast_to_typed.py | 1 + tests/test_channel_pad.py | 1 + tests/test_check_hash.py | 1 + tests/test_check_missing_files.py | 1 + tests/test_classes_to_indices.py | 1 + tests/test_classes_to_indicesd.py | 1 + tests/test_cldice_loss.py | 1 + tests/test_clip_intensity_percentiles.py | 185 ++++++++++ tests/test_clip_intensity_percentilesd.py | 205 +++++++++++ tests/test_complex_utils.py | 1 + tests/test_component_locator.py | 1 + tests/test_component_store.py | 1 + tests/test_compose.py | 24 +- tests/test_compose_get_number_conversions.py | 7 + tests/test_compute_confusion_matrix.py | 1 + tests/test_compute_f_beta.py | 37 +- tests/test_compute_fid_metric.py | 1 + tests/test_compute_froc.py | 3 + tests/test_compute_generalized_dice.py | 1 + tests/test_compute_ho_ver_maps.py | 1 + tests/test_compute_ho_ver_maps_d.py | 1 + tests/test_compute_meandice.py | 1 + tests/test_compute_meaniou.py | 1 + tests/test_compute_mmd_metric.py | 1 + tests/test_compute_multiscalessim_metric.py | 1 + tests/test_compute_panoptic_quality.py | 1 + tests/test_compute_regression_metrics.py | 1 + tests/test_compute_roc_auc.py | 1 + tests/test_compute_variance.py | 1 + tests/test_concat_itemsd.py | 1 + tests/test_config_item.py | 3 +- tests/test_config_parser.py | 6 +- tests/test_conjugate_gradient.py | 56 +++ tests/test_contrastive_loss.py | 1 + tests/test_convert_data_type.py | 1 + tests/test_convert_to_multi_channel.py | 1 + tests/test_convert_to_multi_channeld.py | 1 + tests/test_convert_to_onnx.py | 20 +- tests/test_convert_to_torchscript.py | 1 + tests/test_convert_to_trt.py | 1 + tests/test_convolutions.py | 3 + tests/test_copy_itemsd.py | 1 + tests/test_copy_model_state.py | 3 + tests/test_correct_crop_centers.py | 1 + .../test_create_cross_validation_datalist.py | 1 + tests/test_create_grid_and_affine.py | 2 + tests/test_crf_cpu.py | 1 + tests/test_crf_cuda.py | 1 + tests/test_crop_foreground.py | 1 + tests/test_crop_foregroundd.py | 1 + tests/test_cross_validation.py | 1 + tests/test_csv_dataset.py | 1 + tests/test_csv_iterable_dataset.py | 1 + tests/test_csv_saver.py | 1 + tests/test_cucim_dict_transform.py | 1 + tests/test_cucim_transform.py | 1 + tests/test_cumulative.py | 1 + tests/test_cumulative_average.py | 1 + tests/test_cumulative_average_dist.py | 1 + tests/test_cv2_dist.py | 1 + tests/test_daf3d.py | 1 + tests/test_data_stats.py | 1 + tests/test_data_statsd.py | 1 + tests/test_dataloader.py | 2 + tests/test_dataset.py | 1 + tests/test_dataset_func.py | 1 + tests/test_dataset_summary.py | 1 + tests/test_decathlondataset.py | 6 +- tests/test_decollate.py | 2 + tests/test_deepedit_interaction.py | 1 + tests/test_deepedit_transforms.py | 11 + tests/test_deepgrow_dataset.py | 1 + tests/test_deepgrow_interaction.py | 1 + tests/test_deepgrow_transforms.py | 11 + tests/test_delete_itemsd.py | 1 + tests/test_denseblock.py | 4 + tests/test_densenet.py | 2 + tests/test_deprecated.py | 8 + tests/test_detect_envelope.py | 2 + tests/test_detection_coco_metrics.py | 1 + tests/test_detector_boxselector.py | 1 + tests/test_detector_utils.py | 1 + tests/test_dev_collate.py | 1 + tests/test_dice_ce_loss.py | 19 +- tests/test_dice_focal_loss.py | 15 +- tests/test_dice_loss.py | 1 + tests/test_diffusion_loss.py | 1 + tests/test_dints_cell.py | 1 + tests/test_dints_mixop.py | 1 + tests/test_dints_network.py | 2 + tests/test_discriminator.py | 1 + tests/test_distance_transform_edt.py | 1 + tests/test_download_and_extract.py | 1 + tests/test_download_url_yandex.py | 1 + tests/test_downsample_block.py | 1 + tests/test_drop_path.py | 1 + tests/test_ds_loss.py | 4 + tests/test_dvf2ddf.py | 1 + tests/test_dynunet.py | 12 +- tests/test_dynunet_block.py | 2 + tests/test_efficientnet.py | 2 + tests/test_ensemble_evaluator.py | 3 + tests/test_ensure_channel_first.py | 1 + tests/test_ensure_channel_firstd.py | 1 + tests/test_ensure_tuple.py | 1 + tests/test_ensure_type.py | 1 + tests/test_ensure_typed.py | 1 + tests/test_enum_bound_interp.py | 1 + tests/test_eval_mode.py | 1 + .../test_evenly_divisible_all_gather_dist.py | 1 + tests/test_factorized_increase.py | 1 + tests/test_factorized_reduce.py | 1 + tests/test_fastmri_reader.py | 1 + tests/test_fft_utils.py | 1 + tests/test_fg_bg_to_indices.py | 1 + tests/test_fg_bg_to_indicesd.py | 1 + tests/test_file_basename.py | 1 + tests/test_fill_holes.py | 1 + tests/test_fill_holesd.py | 1 + tests/test_fl_exchange_object.py | 1 + tests/test_fl_monai_algo.py | 7 +- tests/test_fl_monai_algo_dist.py | 1 + tests/test_fl_monai_algo_stats.py | 1 + tests/test_flatten_sub_keysd.py | 1 + tests/test_flexible_unet.py | 138 ++----- tests/test_flip.py | 1 + tests/test_flipd.py | 1 + tests/test_focal_loss.py | 3 +- tests/test_folder_layout.py | 1 + tests/test_foreground_mask.py | 1 + tests/test_foreground_maskd.py | 1 + tests/test_fourier.py | 1 + tests/test_fpn_block.py | 2 + tests/test_freeze_layers.py | 1 + tests/test_from_engine_hovernet.py | 1 + tests/test_fullyconnectednet.py | 1 + tests/test_gaussian.py | 1 + tests/test_gaussian_filter.py | 6 +- tests/test_gaussian_sharpen.py | 1 + tests/test_gaussian_sharpend.py | 1 + tests/test_gaussian_smooth.py | 1 + tests/test_gaussian_smoothd.py | 1 + tests/test_gdsdataset.py | 2 + tests/test_generalized_dice_focal_loss.py | 15 +- tests/test_generalized_dice_loss.py | 1 + .../test_generalized_wasserstein_dice_loss.py | 2 + tests/test_generate_distance_map.py | 1 + tests/test_generate_distance_mapd.py | 1 + tests/test_generate_instance_border.py | 1 + tests/test_generate_instance_borderd.py | 1 + tests/test_generate_instance_centroid.py | 1 + tests/test_generate_instance_centroidd.py | 1 + tests/test_generate_instance_contour.py | 1 + tests/test_generate_instance_contourd.py | 1 + tests/test_generate_instance_type.py | 1 + tests/test_generate_instance_typed.py | 1 + ...est_generate_label_classes_crop_centers.py | 1 + tests/test_generate_param_groups.py | 1 + ...est_generate_pos_neg_label_crop_centers.py | 1 + tests/test_generate_spatial_bounding_box.py | 1 + tests/test_generate_succinct_contour.py | 1 + tests/test_generate_succinct_contourd.py | 1 + tests/test_generate_watershed_markers.py | 1 + tests/test_generate_watershed_markersd.py | 1 + tests/test_generate_watershed_mask.py | 1 + tests/test_generate_watershed_maskd.py | 1 + tests/test_generator.py | 1 + tests/test_get_equivalent_dtype.py | 1 + tests/test_get_extreme_points.py | 1 + tests/test_get_layers.py | 2 + tests/test_get_package_version.py | 1 + tests/test_get_unique_labels.py | 1 + tests/test_gibbs_noise.py | 1 + tests/test_gibbs_noised.py | 1 + tests/test_giou_loss.py | 1 + tests/test_global_mutual_information_loss.py | 42 ++- tests/test_globalnet.py | 2 + tests/test_gmm.py | 1 + tests/test_grid_dataset.py | 1 + tests/test_grid_distortion.py | 1 + tests/test_grid_distortiond.py | 1 + tests/test_grid_patch.py | 1 + tests/test_grid_patchd.py | 1 + tests/test_grid_pull.py | 1 + tests/test_grid_split.py | 1 + tests/test_grid_splitd.py | 1 + tests/test_handler_checkpoint_loader.py | 1 + tests/test_handler_checkpoint_saver.py | 1 + tests/test_handler_classification_saver.py | 1 + .../test_handler_classification_saver_dist.py | 1 + tests/test_handler_clearml_image.py | 1 + tests/test_handler_clearml_stats.py | 1 + tests/test_handler_confusion_matrix_dist.py | 1 + tests/test_handler_decollate_batch.py | 1 + tests/test_handler_early_stop.py | 3 + tests/test_handler_garbage_collector.py | 1 + tests/test_handler_ignite_metric.py | 1 + tests/test_handler_logfile.py | 1 + tests/test_handler_lr_scheduler.py | 1 + tests/test_handler_metric_logger.py | 1 + tests/test_handler_metrics_reloaded.py | 2 + tests/test_handler_metrics_saver.py | 1 + tests/test_handler_metrics_saver_dist.py | 1 + tests/test_handler_mlflow.py | 2 + tests/test_handler_nvtx.py | 1 + tests/test_handler_panoptic_quality.py | 1 + tests/test_handler_parameter_scheduler.py | 3 + tests/test_handler_post_processing.py | 1 + tests/test_handler_prob_map_producer.py | 3 + tests/test_handler_regression_metrics.py | 1 + tests/test_handler_regression_metrics_dist.py | 4 + tests/test_handler_rocauc.py | 1 + tests/test_handler_rocauc_dist.py | 1 + tests/test_handler_smartcache.py | 1 + tests/test_handler_stats.py | 2 + tests/test_handler_tb_image.py | 1 + tests/test_handler_tb_stats.py | 2 + tests/test_handler_validation.py | 2 + tests/test_hardnegsampler.py | 1 + tests/test_hashing.py | 2 + tests/test_hausdorff_distance.py | 1 + tests/test_hausdorff_loss.py | 24 +- tests/test_header_correct.py | 1 + tests/test_highresnet.py | 1 + tests/test_hilbert_transform.py | 3 + tests/test_histogram_normalize.py | 1 + tests/test_histogram_normalized.py | 1 + tests/test_hovernet.py | 1 + ...t_hovernet_instance_map_post_processing.py | 1 + ..._hovernet_instance_map_post_processingd.py | 1 + tests/test_hovernet_loss.py | 2 + ...t_hovernet_nuclear_type_post_processing.py | 1 + ..._hovernet_nuclear_type_post_processingd.py | 1 + tests/test_identity.py | 1 + tests/test_identityd.py | 1 + tests/test_image_dataset.py | 2 + tests/test_image_filter.py | 5 + tests/test_image_rw.py | 4 + tests/test_img2tensorboard.py | 1 + tests/test_init_reader.py | 1 + tests/test_integration_autorunner.py | 1 + tests/test_integration_bundle_run.py | 3 + tests/test_integration_classification_2d.py | 2 + tests/test_integration_determinism.py | 3 + tests/test_integration_fast_train.py | 1 + tests/test_integration_gpu_customization.py | 1 + tests/test_integration_lazy_samples.py | 1 + tests/test_integration_nnunetv2_runner.py | 1 + tests/test_integration_segmentation_3d.py | 1 + tests/test_integration_sliding_window.py | 1 + tests/test_integration_stn.py | 1 + tests/test_integration_unet_2d.py | 3 + tests/test_integration_workers.py | 1 + tests/test_integration_workflows.py | 3 + tests/test_integration_workflows_gan.py | 1 + tests/test_intensity_stats.py | 1 + tests/test_intensity_statsd.py | 1 + tests/test_inverse_array.py | 1 + tests/test_invert.py | 1 + tests/test_invertd.py | 1 + tests/test_is_supported_format.py | 1 + tests/test_iterable_dataset.py | 2 + tests/test_itk_torch_bridge.py | 2 + tests/test_itk_writer.py | 1 + tests/test_k_space_spike_noise.py | 1 + tests/test_k_space_spike_noised.py | 1 + .../test_keep_largest_connected_component.py | 1 + .../test_keep_largest_connected_componentd.py | 1 + tests/test_kspace_mask.py | 1 + tests/test_label_filter.py | 1 + tests/test_label_filterd.py | 1 + tests/test_label_quality_score.py | 1 + tests/test_label_to_contour.py | 1 + tests/test_label_to_contourd.py | 1 + tests/test_label_to_mask.py | 1 + tests/test_label_to_maskd.py | 1 + tests/test_lambda.py | 1 + tests/test_lambdad.py | 1 + tests/test_lesion_froc.py | 1 + tests/test_list_data_collate.py | 1 + tests/test_list_to_dict.py | 1 + tests/test_lltm.py | 1 + tests/test_lmdbdataset.py | 2 + tests/test_lmdbdataset_dist.py | 2 + tests/test_load_decathlon_datalist.py | 1 + tests/test_load_image.py | 2 + tests/test_load_imaged.py | 3 + tests/test_load_spacing_orientation.py | 1 + tests/test_loader_semaphore.py | 1 + ...local_normalized_cross_correlation_loss.py | 1 + tests/test_localnet.py | 1 + tests/test_localnet_block.py | 3 + tests/test_look_up_option.py | 1 + tests/test_loss_metric.py | 1 + tests/test_lr_finder.py | 1 + tests/test_lr_scheduler.py | 2 + tests/test_make_nifti.py | 1 + tests/test_map_binary_to_indices.py | 1 + tests/test_map_classes_to_indices.py | 1 + tests/test_map_label_value.py | 1 + tests/test_map_label_valued.py | 1 + tests/test_map_transform.py | 2 + tests/test_mask_intensity.py | 1 + tests/test_mask_intensityd.py | 1 + tests/test_masked_dice_loss.py | 1 + tests/test_masked_loss.py | 1 + tests/test_masked_patch_wsi_dataset.py | 3 + tests/test_matshow3d.py | 1 + tests/test_mean_ensemble.py | 1 + tests/test_mean_ensembled.py | 1 + tests/test_median_filter.py | 20 +- tests/test_median_smooth.py | 1 + tests/test_median_smoothd.py | 1 + tests/test_mednistdataset.py | 6 +- tests/test_meta_affine.py | 1 + tests/test_meta_tensor.py | 1 + tests/test_metatensor_integration.py | 1 + tests/test_metrics_reloaded.py | 1 + tests/test_milmodel.py | 1 + tests/test_mlp.py | 1 + tests/test_mmar_download.py | 1 + tests/test_module_list.py | 1 + tests/test_monai_env_vars.py | 1 + tests/test_monai_utils_misc.py | 13 +- tests/test_mri_utils.py | 1 + tests/test_multi_scale.py | 30 +- tests/test_net_adapter.py | 1 + tests/test_network_consistency.py | 1 + tests/test_nifti_endianness.py | 1 + tests/test_nifti_header_revise.py | 1 + tests/test_nifti_rw.py | 1 + tests/test_normalize_intensity.py | 1 + tests/test_normalize_intensityd.py | 1 + tests/test_npzdictitemdataset.py | 1 + tests/test_nrrd_reader.py | 1 + tests/test_nuclick_transforms.py | 9 + tests/test_numpy_reader.py | 1 + tests/test_nvtx_decorator.py | 1 + tests/test_nvtx_transform.py | 1 + tests/test_occlusion_sensitivity.py | 2 + tests/test_one_of.py | 12 + tests/test_optional_import.py | 28 +- tests/test_ori_ras_lps.py | 1 + tests/test_orientation.py | 1 + tests/test_orientationd.py | 1 + tests/test_p3d_block.py | 1 + tests/test_pad_collation.py | 4 +- tests/test_pad_mode.py | 1 + tests/test_partition_dataset.py | 1 + tests/test_partition_dataset_classes.py | 1 + tests/test_patch_dataset.py | 1 + tests/test_patch_inferer.py | 1 + tests/test_patch_wsi_dataset.py | 3 + tests/test_patchembedding.py | 28 ++ tests/test_pathology_he_stain.py | 2 + tests/test_pathology_he_stain_dict.py | 2 + tests/test_pathology_prob_nms.py | 1 + tests/test_perceptual_loss.py | 48 ++- tests/test_persistentdataset.py | 2 + tests/test_persistentdataset_dist.py | 3 + tests/test_phl_cpu.py | 1 + tests/test_phl_cuda.py | 1 + tests/test_pil_reader.py | 1 + tests/test_plot_2d_or_3d_image.py | 1 + tests/test_png_rw.py | 1 + tests/test_polyval.py | 1 + tests/test_prepare_batch_default.py | 102 ++---- tests/test_prepare_batch_default_dist.py | 2 + tests/test_prepare_batch_extra_input.py | 2 + tests/test_prepare_batch_hovernet.py | 2 + tests/test_preset_filters.py | 6 + tests/test_print_info.py | 1 + tests/test_print_transform_backends.py | 1 + tests/test_probnms.py | 1 + tests/test_probnmsd.py | 1 + tests/test_profiling.py | 1 + tests/test_pytorch_version_after.py | 1 + tests/test_query_memory.py | 1 + tests/test_quicknat.py | 1 + tests/test_rand_adjust_contrast.py | 1 + tests/test_rand_adjust_contrastd.py | 1 + tests/test_rand_affine.py | 10 +- tests/test_rand_affine_grid.py | 1 + tests/test_rand_affined.py | 14 +- tests/test_rand_axis_flip.py | 1 + tests/test_rand_axis_flipd.py | 1 + tests/test_rand_bias_field.py | 1 + tests/test_rand_bias_fieldd.py | 1 + tests/test_rand_coarse_dropout.py | 1 + tests/test_rand_coarse_dropoutd.py | 1 + tests/test_rand_coarse_shuffle.py | 1 + tests/test_rand_coarse_shuffled.py | 1 + tests/test_rand_crop_by_label_classes.py | 1 + tests/test_rand_crop_by_label_classesd.py | 1 + tests/test_rand_crop_by_pos_neg_label.py | 1 + tests/test_rand_crop_by_pos_neg_labeld.py | 1 + tests/test_rand_cucim_dict_transform.py | 1 + tests/test_rand_cucim_transform.py | 1 + tests/test_rand_deform_grid.py | 1 + tests/test_rand_elastic_2d.py | 1 + tests/test_rand_elastic_3d.py | 1 + tests/test_rand_elasticd_2d.py | 1 + tests/test_rand_elasticd_3d.py | 1 + tests/test_rand_flip.py | 1 + tests/test_rand_flipd.py | 1 + tests/test_rand_gaussian_noise.py | 13 +- tests/test_rand_gaussian_noised.py | 15 +- tests/test_rand_gaussian_sharpen.py | 1 + tests/test_rand_gaussian_sharpend.py | 1 + tests/test_rand_gaussian_smooth.py | 1 + tests/test_rand_gaussian_smoothd.py | 1 + tests/test_rand_gibbs_noise.py | 10 + tests/test_rand_gibbs_noised.py | 9 + tests/test_rand_grid_distortion.py | 1 + tests/test_rand_grid_distortiond.py | 1 + tests/test_rand_grid_patch.py | 1 + tests/test_rand_grid_patchd.py | 1 + tests/test_rand_histogram_shift.py | 1 + tests/test_rand_histogram_shiftd.py | 1 + tests/test_rand_k_space_spike_noise.py | 1 + tests/test_rand_k_space_spike_noised.py | 1 + tests/test_rand_lambda.py | 1 + tests/test_rand_lambdad.py | 1 + tests/test_rand_rician_noise.py | 1 + tests/test_rand_rician_noised.py | 1 + tests/test_rand_rotate.py | 3 + tests/test_rand_rotate90.py | 1 + tests/test_rand_rotate90d.py | 1 + tests/test_rand_rotated.py | 2 + tests/test_rand_scale_intensity.py | 1 + tests/test_rand_scale_intensity_fixed_mean.py | 1 + .../test_rand_scale_intensity_fixed_meand.py | 1 + tests/test_rand_scale_intensityd.py | 1 + tests/test_rand_shift_intensity.py | 1 + tests/test_rand_shift_intensityd.py | 1 + tests/test_rand_simulate_low_resolution.py | 1 + tests/test_rand_simulate_low_resolutiond.py | 1 + tests/test_rand_spatial_crop_samplesd.py | 1 + tests/test_rand_std_shift_intensity.py | 1 + tests/test_rand_std_shift_intensityd.py | 1 + tests/test_rand_weighted_cropd.py | 1 + tests/test_rand_zoom.py | 1 + tests/test_rand_zoomd.py | 1 + tests/test_randidentity.py | 2 + tests/test_random_order.py | 4 + tests/test_randomizable.py | 2 + tests/test_randomizable_transform_type.py | 2 + tests/test_randtorchvisiond.py | 1 + tests/test_rankfilter_dist.py | 2 + tests/test_recon_net_utils.py | 1 + ...est_reference_based_normalize_intensity.py | 1 + tests/test_reference_based_spatial_cropd.py | 1 + tests/test_reference_resolver.py | 1 + tests/test_reg_loss_integration.py | 2 + tests/test_regularization.py | 112 ++++++ tests/test_regunet.py | 1 + tests/test_regunet_block.py | 3 + tests/test_remove_repeated_channel.py | 1 + tests/test_remove_repeated_channeld.py | 1 + tests/test_remove_small_objects.py | 1 + tests/test_repeat_channel.py | 1 + tests/test_repeat_channeld.py | 1 + tests/test_replace_module.py | 1 + tests/test_require_pkg.py | 4 + tests/test_resample.py | 1 + tests/test_resample_backends.py | 1 + tests/test_resample_datalist.py | 1 + tests/test_resample_to_match.py | 1 + tests/test_resample_to_matchd.py | 1 + tests/test_resampler.py | 1 + tests/test_resize.py | 11 +- tests/test_resize_with_pad_or_crop.py | 1 + tests/test_resize_with_pad_or_cropd.py | 1 + tests/test_resized.py | 10 +- tests/test_resnet.py | 40 ++- tests/test_retinanet.py | 1 + tests/test_retinanet_detector.py | 2 + tests/test_retinanet_predict_utils.py | 3 + tests/test_rotate.py | 2 + tests/test_rotate90.py | 3 + tests/test_rotate90d.py | 1 + tests/test_rotated.py | 3 + tests/test_safe_dtype_range.py | 1 + tests/test_saliency_inferer.py | 1 + tests/test_sample_slices.py | 1 + tests/test_sampler_dist.py | 1 + tests/test_save_classificationd.py | 1 + tests/test_save_image.py | 1 + tests/test_save_imaged.py | 3 + tests/test_save_state.py | 1 + tests/test_savitzky_golay_filter.py | 4 + tests/test_savitzky_golay_smooth.py | 1 + tests/test_savitzky_golay_smoothd.py | 1 + tests/test_scale_intensity.py | 1 + tests/test_scale_intensity_fixed_mean.py | 1 + tests/test_scale_intensity_range.py | 1 + .../test_scale_intensity_range_percentiles.py | 1 + ...test_scale_intensity_range_percentilesd.py | 1 + tests/test_scale_intensity_ranged.py | 1 + tests/test_scale_intensityd.py | 1 + tests/test_se_block.py | 1 + tests/test_se_blocks.py | 2 + tests/test_seg_loss_integration.py | 2 + tests/test_segresnet.py | 2 + tests/test_segresnet_block.py | 1 + tests/test_segresnet_ds.py | 1 + tests/test_select_cross_validation_folds.py | 1 + tests/test_select_itemsd.py | 1 + tests/test_selfattention.py | 1 + tests/test_senet.py | 2 + tests/test_separable_filter.py | 1 + tests/test_set_determinism.py | 2 + tests/test_set_visible_devices.py | 4 +- tests/test_shift_intensity.py | 1 + tests/test_shift_intensityd.py | 1 + tests/test_shuffle_buffer.py | 1 + tests/test_signal_continuouswavelet.py | 1 + tests/test_signal_fillempty.py | 2 + tests/test_signal_fillemptyd.py | 2 + tests/test_signal_rand_add_gaussiannoise.py | 2 + tests/test_signal_rand_add_sine.py | 2 + tests/test_signal_rand_add_sine_partial.py | 2 + tests/test_signal_rand_add_squarepulse.py | 2 + ...est_signal_rand_add_squarepulse_partial.py | 2 + tests/test_signal_rand_drop.py | 2 + tests/test_signal_rand_scale.py | 2 + tests/test_signal_rand_shift.py | 2 + tests/test_signal_remove_frequency.py | 2 + tests/test_simple_aspp.py | 1 + tests/test_simulatedelay.py | 1 + tests/test_simulatedelayd.py | 1 + tests/test_skip_connection.py | 1 + tests/test_slice_inferer.py | 1 + tests/test_sliding_patch_wsi_dataset.py | 3 + .../test_sliding_window_hovernet_inference.py | 1 + tests/test_sliding_window_inference.py | 2 + tests/test_sliding_window_splitter.py | 1 + tests/test_smartcachedataset.py | 1 + tests/test_smooth_field.py | 1 + tests/test_soft_clip.py | 125 +++++++ tests/test_some_of.py | 6 + tests/test_spacing.py | 1 + tests/test_spacingd.py | 1 + tests/test_spade_autoencoderkl.py | 49 ++- tests/test_spatial_combine_transforms.py | 1 + tests/test_spatial_resample.py | 1 + tests/test_spatial_resampled.py | 10 +- tests/test_spectral_loss.py | 1 + tests/test_splitdim.py | 1 + tests/test_squeeze_unsqueeze.py | 1 + tests/test_squeezedim.py | 1 + tests/test_squeezedimd.py | 1 + tests/test_ssim_loss.py | 1 + tests/test_ssim_metric.py | 1 + tests/test_state_cacher.py | 1 + tests/test_std_shift_intensity.py | 1 + tests/test_std_shift_intensityd.py | 1 + tests/test_str2bool.py | 1 + tests/test_str2list.py | 1 + tests/test_subpixel_upsample.py | 1 + tests/test_sure_loss.py | 72 ++++ tests/test_surface_dice.py | 1 + tests/test_surface_distance.py | 1 + tests/test_swin_unetr.py | 1 + tests/test_synthetic.py | 1 + tests/test_tciadataset.py | 5 +- tests/test_testtimeaugmentation.py | 1 + tests/test_text_encoding.py | 1 + tests/test_thread_buffer.py | 1 + tests/test_threadcontainer.py | 1 + tests/test_threshold_intensity.py | 1 + tests/test_threshold_intensityd.py | 1 + tests/test_timedcall_dist.py | 1 + tests/test_to_contiguous.py | 1 + tests/test_to_cupy.py | 1 + tests/test_to_cupyd.py | 1 + tests/test_to_device.py | 1 + tests/test_to_deviced.py | 1 + tests/test_to_from_meta_tensord.py | 1 + tests/test_to_numpy.py | 3 +- tests/test_to_numpyd.py | 1 + tests/test_to_onehot.py | 1 + tests/test_to_pil.py | 1 + tests/test_to_pild.py | 1 + tests/test_to_tensor.py | 1 + tests/test_to_tensord.py | 1 + tests/test_torchscript_utils.py | 2 + tests/test_torchvision.py | 1 + tests/test_torchvision_fc_model.py | 2 + tests/test_torchvisiond.py | 1 + tests/test_traceable_transform.py | 2 + tests/test_train_mode.py | 1 + tests/test_trainable_bilateral.py | 2 + tests/test_trainable_joint_bilateral.py | 2 + tests/test_transchex.py | 1 + tests/test_transform.py | 2 + tests/test_transformerblock.py | 1 + tests/test_transpose.py | 1 + tests/test_transposed.py | 1 + tests/test_tversky_loss.py | 12 +- ...est_ultrasound_confidence_map_transform.py | 184 +++------- tests/test_unet.py | 1 + tests/test_unetr.py | 1 + tests/test_unetr_block.py | 3 + tests/test_unified_focal_loss.py | 1 + tests/test_upsample_block.py | 1 + tests/test_utils_pytorch_numpy_unification.py | 1 + tests/test_varautoencoder.py | 1 + tests/test_varnet.py | 1 + tests/test_version.py | 1 + tests/test_video_datasets.py | 1 + tests/test_vis_cam.py | 1 + tests/test_vis_gradbased.py | 2 + tests/test_vis_gradcam.py | 2 + tests/test_vit.py | 100 ++---- tests/test_vitautoenc.py | 98 ++--- tests/test_vnet.py | 1 + tests/test_vote_ensemble.py | 1 + tests/test_vote_ensembled.py | 1 + tests/test_voxelmorph.py | 1 + tests/test_warp.py | 1 + tests/test_watershed.py | 1 + tests/test_watershedd.py | 1 + tests/test_weight_init.py | 1 + tests/test_weighted_random_sampler_dist.py | 1 + tests/test_with_allow_missing_keys.py | 1 + tests/test_write_metrics_reports.py | 1 + tests/test_wsi_sliding_window_splitter.py | 2 + tests/test_wsireader.py | 4 + tests/test_zarr_avg_merger.py | 1 + tests/test_zipdataset.py | 2 + tests/test_zoom.py | 1 + tests/test_zoom_affine.py | 1 + tests/test_zoomd.py | 1 + tests/testing_data/fl_infer_properties.json | 67 ++++ tests/testing_data/integration_answers.py | 56 +++ tests/utils.py | 2 + 810 files changed, 4480 insertions(+), 1126 deletions(-) create mode 100644 monai/losses/barlow_twins.py create mode 100644 monai/losses/sure_loss.py create mode 100644 monai/networks/layers/conjugate_gradient.py create mode 100644 monai/transforms/regularization/__init__.py create mode 100644 monai/transforms/regularization/array.py create mode 100644 monai/transforms/regularization/dictionary.py create mode 100644 tests/test_barlow_twins_loss.py create mode 100644 tests/test_clip_intensity_percentiles.py create mode 100644 tests/test_clip_intensity_percentilesd.py create mode 100644 tests/test_conjugate_gradient.py create mode 100644 tests/test_regularization.py create mode 100644 tests/test_soft_clip.py create mode 100644 tests/test_sure_loss.py create mode 100644 tests/testing_data/fl_infer_properties.json diff --git a/.github/workflows/chatops.yml b/.github/workflows/chatops.yml index 59c7d070b4..6f3b1c293d 100644 --- a/.github/workflows/chatops.yml +++ b/.github/workflows/chatops.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest steps: - name: dispatch - uses: peter-evans/slash-command-dispatch@v3.0.2 + uses: peter-evans/slash-command-dispatch@v4.0.0 with: token: ${{ secrets.PR_MAINTAIN }} reaction-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/conda.yml b/.github/workflows/conda.yml index a387c77ebd..367a24cbde 100644 --- a/.github/workflows/conda.yml +++ b/.github/workflows/conda.yml @@ -26,7 +26,7 @@ jobs: steps: - if: runner.os == 'windows' name: Config pagefile (Windows only) - uses: al-cheb/configure-pagefile-action@v1.3 + uses: al-cheb/configure-pagefile-action@v1.4 with: minimum-size: 8GB maximum-size: 16GB diff --git a/.github/workflows/cron-ngc-bundle.yml b/.github/workflows/cron-ngc-bundle.yml index 84666204a9..bd45bc8d1e 100644 --- a/.github/workflows/cron-ngc-bundle.yml +++ b/.github/workflows/cron-ngc-bundle.yml @@ -26,7 +26,7 @@ jobs: id: pip-cache run: echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT - name: cache for pip - uses: actions/cache@v3 + uses: actions/cache@v4 id: cache with: path: ~/.cache/pip diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index e981280ff9..0f9e6cd480 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -19,18 +19,18 @@ jobs: - "PTLATEST+CUDA121" include: # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes - - environment: PT191+CUDA113 - pytorch: "torch==1.9.1 torchvision==0.10.1 --extra-index-url https://download.pytorch.org/whl/cu113" - base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3 - environment: PT110+CUDA113 pytorch: "torch==1.10.2 torchvision==0.11.3 --extra-index-url https://download.pytorch.org/whl/cu113" base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3 - environment: PT113+CUDA113 pytorch: "torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu113" base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3 - - environment: PTLATEST+CUDA121 - pytorch: "-U torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118" + - environment: PT113+CUDA122 + pytorch: "torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu121" base: "nvcr.io/nvidia/pytorch:23.08-py3" # CUDA 12.2 + - environment: PTLATEST+CUDA124 + pytorch: "-U torch torchvision --extra-index-url https://download.pytorch.org/whl/cu121" + base: "nvcr.io/nvidia/pytorch:24.03-py3" # CUDA 12.4 container: image: ${{ matrix.base }} options: "--gpus all" @@ -67,7 +67,7 @@ jobs: if pgrep python; then pkill python; fi shell: bash - name: Upload coverage - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: fail_ci_if_error: false files: ./coverage.xml @@ -76,7 +76,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:22.10", "pytorch:23.08"] + container: ["pytorch:23.08", "pytorch:24.03"] container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -111,7 +111,7 @@ jobs: if pgrep python; then pkill python; fi shell: bash - name: Upload coverage - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: fail_ci_if_error: false files: ./coverage.xml @@ -121,7 +121,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:23.08"] + container: ["pytorch:24.03"] container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -212,7 +212,7 @@ jobs: if pgrep python; then pkill python; fi shell: bash - name: Upload coverage - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: fail_ci_if_error: false files: ./coverage.xml @@ -221,7 +221,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' needs: cron-gpu # so that monai itself is verified first container: - image: nvcr.io/nvidia/pytorch:23.08-py3 # testing with the latest pytorch base image + image: nvcr.io/nvidia/pytorch:24.03-py3 # testing with the latest pytorch base image options: "--gpus all --ipc=host" runs-on: [self-hosted, linux, x64, integration] steps: diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 229ae675f5..65716f86f9 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -17,7 +17,8 @@ jobs: versioning_dev: # compute versioning file from python setup.py # upload as artifact - if: github.repository == 'Project-MONAI/MONAI' + # if: github.repository == 'Project-MONAI/MONAI' + if: ${{ false }} # disable docker build job project-monai/monai#7450 runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -47,8 +48,8 @@ jobs: rm -rf {*,.[^.]*} docker_build_dev: - # builds projectmonai/monai:latest - if: github.repository == 'Project-MONAI/MONAI' + # if: github.repository == 'Project-MONAI/MONAI' + if: ${{ false }} # disable docker build job project-monai/monai#7450 needs: versioning_dev runs-on: ubuntu-latest steps: @@ -62,7 +63,7 @@ jobs: - name: docker_build shell: bash run: | - rm -rf /opt/hostedtoolcache + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; docker --version # get tag info for versioning cat _version.py diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 952b2d8deb..c82530a551 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -22,7 +22,7 @@ jobs: id: pip-cache run: echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT - name: cache for pip - uses: actions/cache@v3 + uses: actions/cache@v4 id: cache with: path: | @@ -71,7 +71,7 @@ jobs: run: ./runtests.sh --build --net - name: Add reaction - uses: peter-evans/create-or-update-comment@v3 + uses: peter-evans/create-or-update-comment@v4 with: token: ${{ secrets.PR_MAINTAIN }} repository: ${{ github.event.client_payload.github.payload.repository.full_name }} @@ -95,7 +95,7 @@ jobs: id: pip-cache run: echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT - name: cache for pip - uses: actions/cache@v3 + uses: actions/cache@v4 id: cache with: path: | @@ -151,7 +151,7 @@ jobs: python -m tests.test_integration_gpu_customization - name: Add reaction - uses: peter-evans/create-or-update-comment@v3 + uses: peter-evans/create-or-update-comment@v4 with: token: ${{ secrets.PR_MAINTAIN }} repository: ${{ github.event.client_payload.github.payload.repository.full_name }} diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml index 0baef949f0..f83d52f8e3 100644 --- a/.github/workflows/pythonapp-gpu.yml +++ b/.github/workflows/pythonapp-gpu.yml @@ -29,10 +29,6 @@ jobs: - "PT210+CUDA121DOCKER" include: # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes - - environment: PT19+CUDA114DOCKER - # 21.10: 1.10.0a0+0aef44c - pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error - base: "nvcr.io/nvidia/pytorch:21.10-py3" - environment: PT110+CUDA111 pytorch: "torch==1.10.2 torchvision==0.11.3 --extra-index-url https://download.pytorch.org/whl/cu111" base: "nvcr.io/nvidia/cuda:11.1.1-devel-ubuntu18.04" @@ -47,6 +43,10 @@ jobs: # 23.08: 2.1.0a0+29c30b1 pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error base: "nvcr.io/nvidia/pytorch:23.08-py3" + - environment: PT210+CUDA121DOCKER + # 24.03: 2.3.0a0+40ec155e58.nv24.3 + pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error + base: "nvcr.io/nvidia/pytorch:24.03-py3" container: image: ${{ matrix.base }} options: --gpus all --env NVIDIA_DISABLE_REQUIRE=true # workaround for unsatisfied condition: cuda>=11.6 @@ -137,6 +137,6 @@ jobs: shell: bash - name: Upload coverage if: ${{ github.head_ref != 'dev' && github.event.pull_request.merged != true }} - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: files: ./coverage.xml diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml index 7b7930bdf5..bbe7579774 100644 --- a/.github/workflows/pythonapp-min.yml +++ b/.github/workflows/pythonapp-min.yml @@ -44,7 +44,7 @@ jobs: echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT shell: bash - name: cache for pip - uses: actions/cache@v3 + uses: actions/cache@v4 id: cache with: path: ${{ steps.pip-cache.outputs.dir }} @@ -90,7 +90,7 @@ jobs: echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT shell: bash - name: cache for pip - uses: actions/cache@v3 + uses: actions/cache@v4 id: cache with: path: ${{ steps.pip-cache.outputs.dir }} @@ -135,7 +135,7 @@ jobs: echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT shell: bash - name: cache for pip - uses: actions/cache@v3 + uses: actions/cache@v4 id: cache with: path: ${{ steps.pip-cache.outputs.dir }} diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 29a79759e0..b7f2cfb9db 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -36,7 +36,7 @@ jobs: run: | echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT - name: cache for pip - uses: actions/cache@v3 + uses: actions/cache@v4 id: cache with: path: ~/.cache/pip @@ -62,7 +62,7 @@ jobs: steps: - if: runner.os == 'windows' name: Config pagefile (Windows only) - uses: al-cheb/configure-pagefile-action@v1.3 + uses: al-cheb/configure-pagefile-action@v1.4 with: minimum-size: 8GB maximum-size: 16GB @@ -83,7 +83,7 @@ jobs: echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT shell: bash - name: cache for pip - uses: actions/cache@v3 + uses: actions/cache@v4 id: cache with: path: ${{ steps.pip-cache.outputs.dir }} @@ -136,7 +136,7 @@ jobs: run: | echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT - name: cache for pip - uses: actions/cache@v3 + uses: actions/cache@v4 id: cache with: path: | @@ -217,7 +217,7 @@ jobs: run: | echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT - name: cache for pip - uses: actions/cache@v3 + uses: actions/cache@v4 id: cache with: path: | diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a03d2cea6c..c134724665 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -27,7 +27,7 @@ jobs: python -m pip install --user --upgrade setuptools wheel - name: Build and test source archive and wheel file run: | - rm -rf /opt/hostedtoolcache + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; git fetch --depth=1 origin +refs/tags/*:refs/tags/* root_dir=$PWD echo "$root_dir" @@ -102,7 +102,7 @@ jobs: python-version: '3.9' - shell: bash run: | - rm -rf /opt/hostedtoolcache + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; git describe python -m pip install --user --upgrade setuptools wheel python setup.py build @@ -143,7 +143,7 @@ jobs: RELEASE_VERSION: ${{ steps.versioning.outputs.tag }} shell: bash run: | - rm -rf /opt/hostedtoolcache + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; # get tag info for versioning mv _version.py monai/ # version checks diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index 82394a86dd..c6ad243b81 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -35,7 +35,7 @@ jobs: echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT - name: cache for pip if: ${{ startsWith(github.ref, 'refs/heads/dev') }} - uses: actions/cache@v3 + uses: actions/cache@v4 id: cache with: path: | @@ -68,7 +68,7 @@ jobs: if pgrep python; then pkill python; fi shell: bash - name: Upload coverage - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: fail_ci_if_error: false files: ./coverage.xml @@ -91,7 +91,7 @@ jobs: run: | echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT - name: cache for pip - uses: actions/cache@v3 + uses: actions/cache@v4 id: cache with: path: | @@ -100,7 +100,7 @@ jobs: key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ steps.pip-cache.outputs.datew }} - name: Install the dependencies run: | - rm -rf /opt/hostedtoolcache + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; python -m pip install --upgrade pip wheel python -m pip install -r requirements-dev.txt - name: Run quick tests CPU ubuntu @@ -111,7 +111,7 @@ jobs: BUILD_MONAI=1 ./runtests.sh --build --quick --min coverage xml --ignore-errors - name: Upload coverage - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: fail_ci_if_error: false files: ./coverage.xml @@ -128,7 +128,7 @@ jobs: run: | echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT - name: cache for pip - uses: actions/cache@v3 + uses: actions/cache@v4 id: cache with: path: | @@ -146,7 +146,7 @@ jobs: - name: Install the default branch with build (dev branch only) if: github.ref == 'refs/heads/dev' run: | - rm -rf /opt/hostedtoolcache + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; BUILD_MONAI=1 pip install git+https://github.com/Project-MONAI/MONAI#egg=MONAI python -c 'import monai; monai.config.print_config()' - name: Get the test cases (dev branch only) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 14b41bbeb8..b71a2bac43 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: - id: end-of-file-fixer - id: mixed-line-ending - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.261 + rev: v0.3.5 hooks: - id: ruff args: @@ -58,7 +58,7 @@ repos: name: Unused noqa additional_dependencies: - flake8>=3.8.1 - - flake8-bugbear + - flake8-bugbear<=24.2.6 - flake8-comprehensions - pep8-naming exclude: | diff --git a/Dockerfile b/Dockerfile index cb1300ea90..fc97227351 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,13 +11,25 @@ # To build with a different base image # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag. -ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:23.08-py3 +ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.03-py3 FROM ${PYTORCH_IMAGE} LABEL maintainer="monai.contact@gmail.com" +# TODO: remark for issue [revise the dockerfile](https://github.com/zarr-developers/numcodecs/issues/431) +RUN if [[ $(uname -m) =~ "aarch64" ]]; then \ + cd /opt && \ + git clone --branch v0.12.1 --recursive https://github.com/zarr-developers/numcodecs && \ + pip wheel numcodecs && \ + rm -r /opt/*.whl && \ + rm -rf /opt/numcodecs; \ + fi + WORKDIR /opt/monai +# remove opencv-python before opencv-python-headless installation +RUN pip uninstall -y opencv && rm /usr/local/lib/python3.10/dist-packages/cv2 -r + # install full deps COPY requirements.txt requirements-min.txt requirements-dev.txt /tmp/ RUN cp /tmp/requirements.txt /tmp/req.bak \ diff --git a/docs/requirements.txt b/docs/requirements.txt index e5bedf8552..5acc437391 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -22,7 +22,7 @@ sphinx-autodoc-typehints==1.11.1 pandas einops transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157 -mlflow>=1.28.0 +mlflow>=1.28.0, <=2.11.3 clearml>=1.10.0rc0 tensorboardX imagecodecs; platform_system == "Linux" or platform_system == "Darwin" diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md index c1e3d5cbe9..c932879b5a 100644 --- a/docs/source/config_syntax.md +++ b/docs/source/config_syntax.md @@ -168,9 +168,10 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k - `_mode_` specifies the operating mode when the component is instantiated or the callable is called. it currently supports the following values: - `"default"` (default) -- return the return value of ``_target_(**kwargs)`` - - `"partial"` -- return a partial function of ``functools.partial(_target_, **kwargs)`` (this is often - useful when some portion of the full set of arguments are supplied to the ``_target_``, and the user wants to - call it with additional arguments later). + - `"callable"` -- return a callable, either as ``_target_`` itself or, if ``kwargs`` are provided, as a + partial function of ``functools.partial(_target_, **kwargs)``. Useful for defining a class or function + that will be instantied or called later. User can pre-define some arguments to the ``_target_`` and call + it with additional arguments later. - `"debug"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``, see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall). diff --git a/docs/source/losses.rst b/docs/source/losses.rst index e929e9d605..ba794af3eb 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -73,6 +73,11 @@ Segmentation Losses .. autoclass:: ContrastiveLoss :members: +`BarlowTwinsLoss` +~~~~~~~~~~~~~~~~~ +.. autoclass:: BarlowTwinsLoss + :members: + `HausdorffDTLoss` ~~~~~~~~~~~~~~~~~ .. autoclass:: HausdorffDTLoss @@ -134,6 +139,11 @@ Reconstruction Losses .. autoclass:: JukeboxLoss :members: +`SURELoss` +~~~~~~~~~~ +.. autoclass:: SURELoss + :members: + Loss Wrappers ------------- diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 556bf12d50..c51f5c88b1 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -426,6 +426,11 @@ Layers .. autoclass:: monai.networks.layers.vector_quantizer.VectorQuantizer :members: +`ConjugateGradient` +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ConjugateGradient + :members: + `Utilities` ~~~~~~~~~~~ .. automodule:: monai.networks.layers.convutils @@ -504,6 +509,11 @@ Nets .. autoclass:: ResNet :members: +`ResNetFeatures` +~~~~~~~~~~~~~~~~ +.. autoclass:: ResNetFeatures + :members: + `SENet` ~~~~~~~ .. autoclass:: SENet diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 8990e7991d..8bd5bfd99f 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -309,6 +309,12 @@ Intensity :members: :special-members: __call__ +`ClipIntensityPercentiles` +"""""""""""""""""""""""""" +.. autoclass:: ClipIntensityPercentiles + :members: + :special-members: __call__ + `RandScaleIntensity` """""""""""""""""""" .. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandScaleIntensity.png @@ -661,6 +667,27 @@ Post-processing :members: :special-members: __call__ +Regularization +^^^^^^^^^^^^^^ + +`CutMix` +"""""""" +.. autoclass:: CutMix + :members: + :special-members: __call__ + +`CutOut` +"""""""" +.. autoclass:: CutOut + :members: + :special-members: __call__ + +`MixUp` +""""""" +.. autoclass:: MixUp + :members: + :special-members: __call__ + Signal ^^^^^^^ @@ -1384,6 +1411,12 @@ Intensity (Dict) :members: :special-members: __call__ +`ClipIntensityPercentilesd` +""""""""""""""""""""""""""" +.. autoclass:: ClipIntensityPercentilesd + :members: + :special-members: __call__ + `RandScaleIntensityd` """"""""""""""""""""" .. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandScaleIntensityd.png @@ -1707,6 +1740,27 @@ Post-processing (Dict) :members: :special-members: __call__ +Regularization (Dict) +^^^^^^^^^^^^^^^^^^^^^ + +`CutMixd` +""""""""" +.. autoclass:: CutMixd + :members: + :special-members: __call__ + +`CutOutd` +""""""""" +.. autoclass:: CutOutd + :members: + :special-members: __call__ + +`MixUpd` +"""""""" +.. autoclass:: MixUpd + :members: + :special-members: __call__ + Signal (Dict) ^^^^^^^^^^^^^ diff --git a/docs/source/transforms_idx.rst b/docs/source/transforms_idx.rst index f4d02a483f..650d45db71 100644 --- a/docs/source/transforms_idx.rst +++ b/docs/source/transforms_idx.rst @@ -74,6 +74,16 @@ Post-processing post.array post.dictionary +Regularization +^^^^^^^^^^^^^^ + +.. autosummary:: + :toctree: _gen + :nosignatures: + + regularization.array + regularization.dictionary + Signal ^^^^^^ diff --git a/monai/__init__.py b/monai/__init__.py index 638220f6df..eb05ac993d 100644 --- a/monai/__init__.py +++ b/monai/__init__.py @@ -83,6 +83,11 @@ from .utils.tf32 import detect_default_tf32 detect_default_tf32() + import torch + + # workaround related to https://github.com/Project-MONAI/MONAI/issues/7575 + if hasattr(torch.cuda.device_count, "cache_clear"): + torch.cuda.device_count.cache_clear() except BaseException: from .utils.misc import MONAIEnvVars diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py index e4c2d908b7..05c961f999 100644 --- a/monai/apps/auto3dseg/auto_runner.py +++ b/monai/apps/auto3dseg/auto_runner.py @@ -85,6 +85,7 @@ class AutoRunner: can be skipped based on the analysis on the dataset from Auto3DSeg DataAnalyzer. mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if the value is None. + mlflow_experiment_name: the name of the experiment in MLflow server. kwargs: image writing parameters for the ensemble inference. The kwargs format follows the SaveImage transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage. @@ -212,6 +213,7 @@ def __init__( templates_path_or_url: str | None = None, allow_skip: bool = True, mlflow_tracking_uri: str | None = None, + mlflow_experiment_name: str | None = None, **kwargs: Any, ): if input is None and os.path.isfile(os.path.join(os.path.abspath(work_dir), "input.yaml")): @@ -253,6 +255,7 @@ def __init__( self.hpo = hpo and has_nni self.hpo_backend = hpo_backend self.mlflow_tracking_uri = mlflow_tracking_uri + self.mlflow_experiment_name = mlflow_experiment_name self.kwargs = deepcopy(kwargs) # parse input config for AutoRunner param overrides @@ -268,7 +271,13 @@ def __init__( if param in self.data_src_cfg and isinstance(self.data_src_cfg[param], bool): setattr(self, param, self.data_src_cfg[param]) # e.g. self.analyze = self.data_src_cfg["analyze"] - for param in ["algos", "hpo_backend", "templates_path_or_url", "mlflow_tracking_uri"]: # override from config + for param in [ + "algos", + "hpo_backend", + "templates_path_or_url", + "mlflow_tracking_uri", + "mlflow_experiment_name", + ]: # override from config if param in self.data_src_cfg: setattr(self, param, self.data_src_cfg[param]) # e.g. self.algos = self.data_src_cfg["algos"] @@ -289,9 +298,13 @@ def __init__( pass # inspect and update folds - num_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename) + self.max_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename) if "num_fold" in self.data_src_cfg: num_fold = int(self.data_src_cfg["num_fold"]) # override from config + logger.info(f"Setting num_fold {num_fold} based on the input config.") + else: + num_fold = self.max_fold + logger.info(f"Setting num_fold {num_fold} based on the input datalist {datalist_filename}.") self.data_src_cfg["datalist"] = datalist_filename # update path to a version in work_dir and save user input ConfigParser.export_config_file( @@ -389,7 +402,10 @@ def inspect_datalist_folds(self, datalist_filename: str) -> int: if len(fold_list) > 0: num_fold = max(fold_list) + 1 - logger.info(f"Setting num_fold {num_fold} based on the input datalist {datalist_filename}.") + logger.info(f"Found num_fold {num_fold} based on the input datalist {datalist_filename}.") + # check if every fold is present + if len(set(fold_list)) != num_fold: + raise ValueError(f"Fold numbers are not continuous from 0 to {num_fold - 1}") elif "validation" in datalist and len(datalist["validation"]) > 0: logger.info("No fold numbers provided, attempting to use a single fold based on the validation key") # update the datalist file @@ -483,6 +499,11 @@ def set_num_fold(self, num_fold: int = 5) -> AutoRunner: if num_fold <= 0: raise ValueError(f"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}") + if num_fold > self.max_fold + 1: + # Auto3DSeg allows no validation set, so the maximum fold number is max_fold + 1 + raise ValueError( + f"num_fold is greater than the maximum fold number {self.max_fold} in {self.datalist_filename}." + ) self.num_fold = num_fold return self @@ -813,6 +834,7 @@ def run(self): data_stats_filename=self.datastats_filename, data_src_cfg_name=self.data_src_cfg_name, mlflow_tracking_uri=self.mlflow_tracking_uri, + mlflow_experiment_name=self.mlflow_experiment_name, ) if self.gpu_customization: diff --git a/monai/apps/auto3dseg/bundle_gen.py b/monai/apps/auto3dseg/bundle_gen.py index 03b9c8bbf4..8a54d18be7 100644 --- a/monai/apps/auto3dseg/bundle_gen.py +++ b/monai/apps/auto3dseg/bundle_gen.py @@ -85,7 +85,8 @@ def __init__(self, template_path: PathLike): self.template_path = template_path self.data_stats_files = "" self.data_list_file = "" - self.mlflow_tracking_uri = None + self.mlflow_tracking_uri: str | None = None + self.mlflow_experiment_name: str | None = None self.output_path = "" self.name = "" self.best_metric = None @@ -139,7 +140,16 @@ def set_mlflow_tracking_uri(self, mlflow_tracking_uri: str | None) -> None: the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if the value is None. """ - self.mlflow_tracking_uri = mlflow_tracking_uri # type: ignore + self.mlflow_tracking_uri = mlflow_tracking_uri + + def set_mlflow_experiment_name(self, mlflow_experiment_name: str | None) -> None: + """ + Set the experiment name for MLflow server + + Args: + mlflow_experiment_name: a string to specify the experiment name for MLflow server. + """ + self.mlflow_experiment_name = mlflow_experiment_name def fill_template_config(self, data_stats_filename: str, algo_path: str, **kwargs: Any) -> dict: """ @@ -447,6 +457,7 @@ class BundleGen(AlgoGen): mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if the value is None. + mlfow_experiment_name: a string to specify the experiment name for MLflow server. .. code-block:: bash python -m monai.apps.auto3dseg BundleGen generate --data_stats_filename="../algorithms/datastats.yaml" @@ -460,6 +471,7 @@ def __init__( data_stats_filename: str | None = None, data_src_cfg_name: str | None = None, mlflow_tracking_uri: str | None = None, + mlflow_experiment_name: str | None = None, ): if algos is None or isinstance(algos, (list, tuple, str)): if templates_path_or_url is None: @@ -513,6 +525,7 @@ def __init__( self.data_stats_filename = data_stats_filename self.data_src_cfg_name = data_src_cfg_name self.mlflow_tracking_uri = mlflow_tracking_uri + self.mlflow_experiment_name = mlflow_experiment_name self.history: list[dict] = [] def set_data_stats(self, data_stats_filename: str) -> None: @@ -552,10 +565,23 @@ def set_mlflow_tracking_uri(self, mlflow_tracking_uri): """ self.mlflow_tracking_uri = mlflow_tracking_uri + def set_mlflow_experiment_name(self, mlflow_experiment_name): + """ + Set the experiment name for MLflow server + + Args: + mlflow_experiment_name: a string to specify the experiment name for MLflow server. + """ + self.mlflow_experiment_name = mlflow_experiment_name + def get_mlflow_tracking_uri(self): """Get the tracking URI for MLflow server""" return self.mlflow_tracking_uri + def get_mlflow_experiment_name(self): + """Get the experiment name for MLflow server""" + return self.mlflow_experiment_name + def get_history(self) -> list: """Get the history of the bundleAlgo object with their names/identifiers""" return self.history @@ -608,10 +634,12 @@ def generate( data_stats = self.get_data_stats() data_src_cfg = self.get_data_src() mlflow_tracking_uri = self.get_mlflow_tracking_uri() + mlflow_experiment_name = self.get_mlflow_experiment_name() gen_algo = deepcopy(algo) gen_algo.set_data_stats(data_stats) gen_algo.set_data_source(data_src_cfg) gen_algo.set_mlflow_tracking_uri(mlflow_tracking_uri) + gen_algo.set_mlflow_experiment_name(mlflow_experiment_name) name = f"{gen_algo.name}_{f_id}" if allow_skip: diff --git a/monai/apps/auto3dseg/ensemble_builder.py b/monai/apps/auto3dseg/ensemble_builder.py index e29745e5cf..b2bea806de 100644 --- a/monai/apps/auto3dseg/ensemble_builder.py +++ b/monai/apps/auto3dseg/ensemble_builder.py @@ -464,7 +464,7 @@ def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFol ensemble_method_name, supported=["AlgoEnsembleBestN", "AlgoEnsembleBestByFold"] ) if self.ensemble_method_name == "AlgoEnsembleBestN": - n_best = kwargs.pop("n_best", False) or 2 + n_best = kwargs.pop("n_best", 2) self.ensemble_method = AlgoEnsembleBestN(n_best=n_best) elif self.ensemble_method_name == "AlgoEnsembleBestByFold": self.ensemble_method = AlgoEnsembleBestByFold(n_fold=self.num_fold) # type: ignore diff --git a/monai/apps/auto3dseg/hpo_gen.py b/monai/apps/auto3dseg/hpo_gen.py index 688bf2b916..b755b99feb 100644 --- a/monai/apps/auto3dseg/hpo_gen.py +++ b/monai/apps/auto3dseg/hpo_gen.py @@ -193,7 +193,9 @@ def generate(self, output_folder: str = ".") -> None: self.obj_filename = os.path.join(write_path, "algo_object.pkl") if isinstance(self.algo, BundleAlgo): - self.algo.export_to_disk(output_folder, task_prefix + task_id, fill_with_datastats=False) + self.algo.export_to_disk( + output_folder, task_prefix + task_id, bundle_root=write_path, fill_with_datastats=False + ) else: ConfigParser.export_config_file(self.params, write_path) logger.info(write_path) diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index bb10eb6b11..67ea3059cc 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -737,6 +737,7 @@ def get_dataset(self, folds: Sequence[int] | int, **dataset_params: Any) -> obje dataset_params_.update(dataset_params) class _NsplitsDataset(self.dataset_cls): # type: ignore + def _split_datalist(self, datalist: list[dict]) -> list[dict]: data = partition_dataset(data=datalist, num_partitions=nfolds, shuffle=True, seed=seed) return select_cross_validation_folds(partitions=data, folds=folds) diff --git a/monai/apps/deepedit/transforms.py b/monai/apps/deepedit/transforms.py index 4c92b42c08..6d0825f54a 100644 --- a/monai/apps/deepedit/transforms.py +++ b/monai/apps/deepedit/transforms.py @@ -34,6 +34,7 @@ class DiscardAddGuidanced(MapTransform): + def __init__( self, keys: KeysCollection, @@ -84,6 +85,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.nda class NormalizeLabelsInDatasetd(MapTransform): + def __init__( self, keys: KeysCollection, label_names: dict[str, int] | None = None, allow_missing_keys: bool = False ): @@ -121,6 +123,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.nda class SingleLabelSelectiond(MapTransform): + def __init__( self, keys: KeysCollection, label_names: Sequence[str] | None = None, allow_missing_keys: bool = False ): diff --git a/monai/apps/detection/metrics/coco.py b/monai/apps/detection/metrics/coco.py index 033b763be5..a856f14fb8 100644 --- a/monai/apps/detection/metrics/coco.py +++ b/monai/apps/detection/metrics/coco.py @@ -72,6 +72,7 @@ class COCOMetric: + def __init__( self, classes: Sequence[str], diff --git a/monai/apps/detection/utils/ATSS_matcher.py b/monai/apps/detection/utils/ATSS_matcher.py index cc9e238862..5b8f950ab3 100644 --- a/monai/apps/detection/utils/ATSS_matcher.py +++ b/monai/apps/detection/utils/ATSS_matcher.py @@ -164,6 +164,7 @@ def compute_matches( class ATSSMatcher(Matcher): + def __init__( self, num_candidates: int = 4, diff --git a/monai/apps/nnunet/nnunetv2_runner.py b/monai/apps/nnunet/nnunetv2_runner.py index e62809403e..8a10849904 100644 --- a/monai/apps/nnunet/nnunetv2_runner.py +++ b/monai/apps/nnunet/nnunetv2_runner.py @@ -22,6 +22,7 @@ from monai.apps.nnunet.utils import analyze_data, create_new_data_copy, create_new_dataset_json from monai.bundle import ConfigParser from monai.utils import ensure_tuple, optional_import +from monai.utils.misc import run_cmd load_pickle, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_pickle") join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join") @@ -37,6 +38,7 @@ class nnUNetV2Runner: # noqa: N801 """ ``nnUNetV2Runner`` provides an interface in MONAI to use `nnU-Net` V2 library to analyze, train, and evaluate neural networks for medical image segmentation tasks. + A version of nnunetv2 higher than 2.2 is needed for this class. ``nnUNetV2Runner`` can be used in two ways: @@ -494,65 +496,64 @@ def train_single_model(self, config: Any, fold: int, gpu_id: tuple | list | int fold: fold of the 5-fold cross-validation. Should be an int between 0 and 4. gpu_id: an integer to select the device to use, or a tuple/list of GPU device indices used for multi-GPU training (e.g., (0,1)). Default: 0. - from nnunetv2.run.run_training import run_training kwargs: this optional parameter allows you to specify additional arguments in - ``nnunetv2.run.run_training.run_training``. Currently supported args are - - plans_identifier: custom plans identifier. Default: "nnUNetPlans". - - pretrained_weights: path to nnU-Net checkpoint file to be used as pretrained model. Will only be - used when actually training. Beta. Use with caution. Default: False. - - use_compressed_data: True to use compressed data for training. Reading compressed data is much - more CPU and (potentially) RAM intensive and should only be used if you know what you are - doing. Default: False. - - continue_training: continue training from latest checkpoint. Default: False. - - only_run_validation: True to run the validation only. Requires training to have finished. - Default: False. - - disable_checkpointing: True to disable checkpointing. Ideal for testing things out and you - don't want to flood your hard drive with checkpoints. Default: False. + ``nnunetv2.run.run_training.run_training_entry``. + + Currently supported args are: + + - p: custom plans identifier. Default: "nnUNetPlans". + - pretrained_weights: path to nnU-Net checkpoint file to be used as pretrained model. Will only be + used when actually training. Beta. Use with caution. Default: False. + - use_compressed: True to use compressed data for training. Reading compressed data is much + more CPU and (potentially) RAM intensive and should only be used if you know what you are + doing. Default: False. + - c: continue training from latest checkpoint. Default: False. + - val: True to run the validation only. Requires training to have finished. + Default: False. + - disable_checkpointing: True to disable checkpointing. Ideal for testing things out and you + don't want to flood your hard drive with checkpoints. Default: False. """ if "num_gpus" in kwargs: kwargs.pop("num_gpus") logger.warning("please use gpu_id to set the GPUs to use") - if "trainer_class_name" in kwargs: - kwargs.pop("trainer_class_name") + if "tr" in kwargs: + kwargs.pop("tr") logger.warning("please specify the `trainer_class_name` in the __init__ of `nnUNetV2Runner`.") - if "export_validation_probabilities" in kwargs: - kwargs.pop("export_validation_probabilities") + if "npz" in kwargs: + kwargs.pop("npz") logger.warning("please specify the `export_validation_probabilities` in the __init__ of `nnUNetV2Runner`.") + cmd = self.train_single_model_command(config, fold, gpu_id, kwargs) + run_cmd(cmd, shell=True) + + def train_single_model_command(self, config, fold, gpu_id, kwargs): if isinstance(gpu_id, (tuple, list)): if len(gpu_id) > 1: gpu_ids_str = "" for _i in range(len(gpu_id)): gpu_ids_str += f"{gpu_id[_i]}," - os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids_str[:-1] + device_setting = f"CUDA_VISIBLE_DEVICES={gpu_ids_str[:-1]}" else: - os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id[0]}" - else: - os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" - - from nnunetv2.run.run_training import run_training - - if isinstance(gpu_id, int) or len(gpu_id) == 1: - run_training( - dataset_name_or_id=self.dataset_name_or_id, - configuration=config, - fold=fold, - trainer_class_name=self.trainer_class_name, - export_validation_probabilities=self.export_validation_probabilities, - **kwargs, - ) + device_setting = f"CUDA_VISIBLE_DEVICES={gpu_id[0]}" else: - run_training( - dataset_name_or_id=self.dataset_name_or_id, - configuration=config, - fold=fold, - num_gpus=len(gpu_id), - trainer_class_name=self.trainer_class_name, - export_validation_probabilities=self.export_validation_probabilities, - **kwargs, - ) + device_setting = f"CUDA_VISIBLE_DEVICES={gpu_id}" + num_gpus = 1 if isinstance(gpu_id, int) or len(gpu_id) == 1 else len(gpu_id) + + cmd = ( + f"{device_setting} nnUNetv2_train " + + f"{self.dataset_name_or_id} {config} {fold} " + + f"-tr {self.trainer_class_name} -num_gpus {num_gpus}" + ) + if self.export_validation_probabilities: + cmd += " --npz" + for _key, _value in kwargs.items(): + if _key == "p" or _key == "pretrained_weights": + cmd += f" -{_key} {_value}" + else: + cmd += f" --{_key} {_value}" + return cmd def train( self, @@ -636,15 +637,7 @@ def train_parallel_cmd( if _config in ensure_tuple(configs): for _i in range(self.num_folds): the_device = gpu_id_for_all[_index % n_devices] # type: ignore - cmd = ( - "python -m monai.apps.nnunet nnUNetV2Runner train_single_model " - + f"--input_config '{self.input_config_or_dict}' --work_dir '{self.work_dir}' " - + f"--config '{_config}' --fold {_i} --gpu_id {the_device} " - + f"--trainer_class_name {self.trainer_class_name} " - + f"--export_validation_probabilities {self.export_validation_probabilities}" - ) - for _key, _value in kwargs.items(): - cmd += f" --{_key} {_value}" + cmd = self.train_single_model_command(_config, _i, the_device, kwargs) all_cmds[-1][the_device].append(cmd) _index += 1 return all_cmds @@ -770,7 +763,7 @@ def find_best_configuration( def predict( self, list_of_lists_or_source_folder: str | list[list[str]], - output_folder: str, + output_folder: str | None | list[str], model_training_output_dir: str, use_folds: tuple[int, ...] | str | None = None, tile_step_size: float = 0.5, @@ -824,7 +817,7 @@ def predict( """ os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" - from nnunetv2.inference.predict_from_raw_data import predict_from_raw_data + from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor n_processes_preprocessing = ( self.default_num_processes if num_processes_preprocessing < 0 else num_processes_preprocessing @@ -832,20 +825,21 @@ def predict( n_processes_segmentation_export = ( self.default_num_processes if num_processes_segmentation_export < 0 else num_processes_segmentation_export ) - - predict_from_raw_data( - list_of_lists_or_source_folder=list_of_lists_or_source_folder, - output_folder=output_folder, - model_training_output_dir=model_training_output_dir, - use_folds=use_folds, + predictor = nnUNetPredictor( tile_step_size=tile_step_size, use_gaussian=use_gaussian, use_mirroring=use_mirroring, - perform_everything_on_gpu=perform_everything_on_gpu, + perform_everything_on_device=perform_everything_on_gpu, verbose=verbose, + ) + predictor.initialize_from_trained_model_folder( + model_training_output_dir=model_training_output_dir, use_folds=use_folds, checkpoint_name=checkpoint_name + ) + predictor.predict_from_files( + list_of_lists_or_source_folder=list_of_lists_or_source_folder, + output_folder_or_list_of_truncated_output_files=output_folder, save_probabilities=save_probabilities, overwrite=overwrite, - checkpoint_name=checkpoint_name, num_processes_preprocessing=n_processes_preprocessing, num_processes_segmentation_export=n_processes_segmentation_export, folder_with_segs_from_prev_stage=folder_with_segs_from_prev_stage, diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index ef9c535725..99e94f89c0 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -162,7 +162,7 @@ def __call__(self, prob_map: NdarrayOrTensor) -> NdarrayOrTensor: pred = label(pred)[0] if self.remove_small_objects is not None: pred = self.remove_small_objects(pred) - pred[pred > 0] = 1 # type: ignore + pred[pred > 0] = 1 return convert_to_dst_type(pred, prob_map, dtype=self.dtype)[0] @@ -338,7 +338,7 @@ def __call__(self, mask: NdarrayOrTensor, instance_border: NdarrayOrTensor) -> N instance_border = instance_border >= self.threshold # uncertain area marker = mask - convert_to_dst_type(instance_border, mask)[0] # certain foreground - marker[marker < 0] = 0 # type: ignore + marker[marker < 0] = 0 marker = self.postprocess_fn(marker) marker = convert_to_numpy(marker) @@ -634,7 +634,7 @@ def __call__( # type: ignore seg_map_crop = convert_to_dst_type(seg_map_crop == instance_id, type_map_crop, dtype=bool)[0] - inst_type = type_map_crop[seg_map_crop] # type: ignore + inst_type = type_map_crop[seg_map_crop] type_list, type_pixels = unique(inst_type, return_counts=True) type_list = list(zip(type_list, type_pixels)) type_list = sorted(type_list, key=lambda x: x[1], reverse=True) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 442dbabba0..db541923b5 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -30,7 +30,7 @@ from monai.config.type_definitions import PathLike from monai.utils import look_up_option, min_version, optional_import -gdown, has_gdown = optional_import("gdown", "4.6.3") +gdown, has_gdown = optional_import("gdown", "4.7.3") if TYPE_CHECKING: from tqdm import tqdm diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py index 56419da4cb..37f3faea21 100644 --- a/monai/auto3dseg/analyzer.py +++ b/monai/auto3dseg/analyzer.py @@ -460,7 +460,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe torch.set_grad_enabled(False) ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore - ndas_label: MetaTensor = d[self.label_key].astype(torch.int8) # (H,W,D) + ndas_label: MetaTensor = d[self.label_key].astype(torch.int16) # (H,W,D) if ndas_label.shape != ndas[0].shape: raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") @@ -472,7 +472,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe if isinstance(ndas_label, (MetaTensor, torch.Tensor)): unique_label = unique_label.data.cpu().numpy() - unique_label = unique_label.astype(np.int8).tolist() + unique_label = unique_label.astype(np.int16).tolist() label_substats = [] # each element is one label pixel_sum = 0 diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py index c6da0a73de..e5122bf3de 100644 --- a/monai/bundle/config_item.py +++ b/monai/bundle/config_item.py @@ -181,7 +181,7 @@ class ConfigComponent(ConfigItem, Instantiable): - ``"_mode_"`` (optional): operating mode for invoking the callable ``component`` defined by ``"_target_"``: - ``"default"``: returns ``component(**kwargs)`` - - ``"partial"``: returns ``functools.partial(component, **kwargs)`` + - ``"callable"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)`` - ``"debug"``: returns ``pdb.runcall(component, **kwargs)`` Other fields in the config content are input arguments to the python module. @@ -289,10 +289,7 @@ def instantiate(self, **kwargs: Any) -> object: mode = self.get_config().get("_mode_", CompInitMode.DEFAULT) args = self.resolve_args() args.update(kwargs) - try: - return instantiate(modname, mode, **args) - except Exception as e: - raise RuntimeError(f"Failed to instantiate {self}") from e + return instantiate(modname, mode, **args) class ConfigExpression(ConfigItem): diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 2565a3cf64..598d938cbd 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -778,10 +778,19 @@ def run( https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. Default to None. tracking: if not None, enable the experiment tracking at runtime with optionally configurable and extensible. - if "mlflow", will add `MLFlowHandler` to the parsed bundle with default tracking settings, - if other string, treat it as file path to load the tracking settings. - if `dict`, treat it as tracking settings. - will patch the target config content with `tracking handlers` and the top-level items of `configs`. + If "mlflow", will add `MLFlowHandler` to the parsed bundle with default tracking settings where a set of + common parameters shown below will be added and can be passed through the `override` parameter of this method. + + - ``"output_dir"``: the path to save mlflow tracking outputs locally, default to "/eval". + - ``"tracking_uri"``: uri to save mlflow tracking outputs, default to "/output_dir/mlruns". + - ``"experiment_name"``: experiment name for this run, default to "monai_experiment". + - ``"run_name"``: the name of current run. + - ``"save_execute_config"``: whether to save the executed config files. It can be `False`, `/path/to/artifacts` + or `True`. If set to `True`, will save to the default path "/eval". Default to `True`. + + If other string, treat it as file path to load the tracking settings. + If `dict`, treat it as tracking settings. + Will patch the target config content with `tracking handlers` and the top-level items of `configs`. for detailed usage examples, please check the tutorial: https://github.com/Project-MONAI/tutorials/blob/main/experiment_management/bundle_integrate_mlflow.ipynb. args_file: a JSON or YAML file to provide default values for `run_id`, `meta_file`, diff --git a/monai/bundle/utils.py b/monai/bundle/utils.py index b187159c89..a0f39d236f 100644 --- a/monai/bundle/utils.py +++ b/monai/bundle/utils.py @@ -113,7 +113,7 @@ "experiment_name": "monai_experiment", "run_name": None, # may fill it at runtime - "execute_config": None, + "save_execute_config": True, "is_not_rank0": ( "$torch.distributed.is_available() \ and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0" @@ -125,7 +125,7 @@ "tracking_uri": "@tracking_uri", "experiment_name": "@experiment_name", "run_name": "@run_name", - "artifacts": "@execute_config", + "artifacts": "@save_execute_config", "iteration_log": True, "epoch_log": True, "tag_name": "train_loss", @@ -148,7 +148,7 @@ "tracking_uri": "@tracking_uri", "experiment_name": "@experiment_name", "run_name": "@run_name", - "artifacts": "@execute_config", + "artifacts": "@save_execute_config", "iteration_log": False, "close_on_complete": True, }, diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index da3aa30141..471088994b 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -11,6 +11,7 @@ from __future__ import annotations +import json import os import sys import time @@ -24,6 +25,7 @@ from monai.bundle.config_parser import ConfigParser from monai.bundle.properties import InferProperties, MetaProperties, TrainProperties from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY +from monai.config import PathLike from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg, deprecated_arg_default, ensure_tuple __all__ = ["BundleWorkflow", "ConfigWorkflow"] @@ -46,6 +48,10 @@ class BundleWorkflow(ABC): or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `None` for common workflow. + properties_path: the path to the JSON file of properties. + meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order. + logging_file: config file for `logging` module in the program. for more details: + https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. """ @@ -59,21 +65,62 @@ class BundleWorkflow(ABC): new_name="workflow_type", msg_suffix="please use `workflow_type` instead.", ) - def __init__(self, workflow_type: str | None = None, workflow: str | None = None): + def __init__( + self, + workflow_type: str | None = None, + workflow: str | None = None, + properties_path: PathLike | None = None, + meta_file: str | Sequence[str] | None = None, + logging_file: str | None = None, + ): + if logging_file is not None: + if not os.path.isfile(logging_file): + raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.") + logger.info(f"Setting logging properties based on config: {logging_file}.") + fileConfig(logging_file, disable_existing_loggers=False) + + if meta_file is not None: + if isinstance(meta_file, str) and not os.path.isfile(meta_file): + logger.error( + f"Cannot find the metadata config file: {meta_file}. " + "Please see: https://docs.monai.io/en/stable/mb_specification.html" + ) + meta_file = None + if isinstance(meta_file, list): + for f in meta_file: + if not os.path.isfile(f): + logger.error( + f"Cannot find the metadata config file: {f}. " + "Please see: https://docs.monai.io/en/stable/mb_specification.html" + ) + meta_file = None + workflow_type = workflow if workflow is not None else workflow_type - if workflow_type is None: + if workflow_type is None and properties_path is None: self.properties = copy(MetaProperties) self.workflow_type = None + self.meta_file = meta_file return - if workflow_type.lower() in self.supported_train_type: + if properties_path is not None: + properties_path = Path(properties_path) + if not properties_path.is_file(): + raise ValueError(f"Property file {properties_path} does not exist.") + with open(properties_path) as json_file: + self.properties = json.load(json_file) + self.workflow_type = None + self.meta_file = meta_file + return + if workflow_type.lower() in self.supported_train_type: # type: ignore[union-attr] self.properties = {**TrainProperties, **MetaProperties} self.workflow_type = "train" - elif workflow_type.lower() in self.supported_infer_type: + elif workflow_type.lower() in self.supported_infer_type: # type: ignore[union-attr] self.properties = {**InferProperties, **MetaProperties} self.workflow_type = "infer" else: raise ValueError(f"Unsupported workflow type: '{workflow_type}'.") + self.meta_file = meta_file + @abstractmethod def initialize(self, *args: Any, **kwargs: Any) -> Any: """ @@ -142,6 +189,13 @@ def get_workflow_type(self): """ return self.workflow_type + def get_meta_file(self): + """ + Get the meta file. + + """ + return self.meta_file + def add_property(self, name: str, required: str, desc: str | None = None) -> None: """ Besides the default predefined properties, some 3rd party applications may need the bundle @@ -206,6 +260,7 @@ class ConfigWorkflow(BundleWorkflow): or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `None` for common workflow. + properties_path: the path to the JSON file of properties. override: id-value pairs to override or add the corresponding config content. e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg`` @@ -230,28 +285,30 @@ def __init__( tracking: str | dict | None = None, workflow_type: str | None = None, workflow: str | None = None, + properties_path: PathLike | None = None, **override: Any, ) -> None: workflow_type = workflow if workflow is not None else workflow_type - super().__init__(workflow_type=workflow_type) if config_file is not None: _config_files = ensure_tuple(config_file) - self.config_root_path = Path(_config_files[0]).parent + config_root_path = Path(_config_files[0]).parent for _config_file in _config_files: _config_file = Path(_config_file) - if _config_file.parent != self.config_root_path: + if _config_file.parent != config_root_path: logger.warn( - f"Not all config files are in {self.config_root_path}. If logging_file and meta_file are" - f"not specified, {self.config_root_path} will be used as the default config root directory." + f"Not all config files are in {config_root_path}. If logging_file and meta_file are" + f"not specified, {config_root_path} will be used as the default config root directory." ) if not _config_file.is_file(): raise FileNotFoundError(f"Cannot find the config file: {_config_file}.") else: - self.config_root_path = Path("configs") - + config_root_path = Path("configs") + meta_file = str(config_root_path / "metadata.json") if meta_file is None else meta_file + super().__init__(workflow_type=workflow_type, meta_file=meta_file, properties_path=properties_path) + self.config_root_path = config_root_path logging_file = str(self.config_root_path / "logging.conf") if logging_file is None else logging_file if logging_file is not None: - if not os.path.exists(logging_file): + if not os.path.isfile(logging_file): if logging_file == str(self.config_root_path / "logging.conf"): logger.warn(f"Default logging file in {logging_file} does not exist, skipping logging.") else: @@ -262,14 +319,8 @@ def __init__( self.parser = ConfigParser() self.parser.read_config(f=config_file) - meta_file = str(self.config_root_path / "metadata.json") if meta_file is None else meta_file - if isinstance(meta_file, str) and not os.path.exists(meta_file): - logger.error( - f"Cannot find the metadata config file: {meta_file}. " - "Please see: https://docs.monai.io/en/stable/mb_specification.html" - ) - else: - self.parser.read_meta(f=meta_file) + if self.meta_file is not None: + self.parser.read_meta(f=self.meta_file) # the rest key-values in the _args are to override config content self.parser.update(pairs=override) @@ -455,13 +506,19 @@ def patch_bundle_tracking(parser: ConfigParser, settings: dict) -> None: parser[k] = v # save the executed config into file default_name = f"config_{time.strftime('%Y%m%d_%H%M%S')}.json" - filepath = parser.get("execute_config", None) - if filepath is None: - if "output_dir" not in parser: - # if no "output_dir" in the bundle config, default to "/eval" - parser["output_dir"] = f"{EXPR_KEY}{ID_REF_KEY}bundle_root + '/eval'" - # experiment management tools can refer to this config item to track the config info - parser["execute_config"] = parser["output_dir"] + f" + '/{default_name}'" - filepath = os.path.join(parser.get_parsed_content("output_dir"), default_name) - Path(filepath).parent.mkdir(parents=True, exist_ok=True) - parser.export_config_file(parser.get(), filepath) + # Users can set the `save_execute_config` to `False`, `/path/to/artifacts` or `True`. + # If set to False, nothing will be recorded. If set to True, the default path will be logged. + # If set to a file path, the given path will be logged. + filepath = parser.get("save_execute_config", True) + if filepath: + if isinstance(filepath, bool): + if "output_dir" not in parser: + # if no "output_dir" in the bundle config, default to "/eval" + parser["output_dir"] = f"{EXPR_KEY}{ID_REF_KEY}bundle_root + '/eval'" + # experiment management tools can refer to this config item to track the config info + parser["save_execute_config"] = parser["output_dir"] + f" + '/{default_name}'" + filepath = os.path.join(parser.get_parsed_content("output_dir"), default_name) + Path(filepath).parent.mkdir(parents=True, exist_ok=True) + parser.export_config_file(parser.get(), filepath) + else: + parser["save_execute_config"] = None diff --git a/monai/data/dataset.py b/monai/data/dataset.py index eba850225d..79e066303e 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -427,7 +427,7 @@ def _transform(self, index: int): class CacheNTransDataset(PersistentDataset): """ - Extension of `PersistentDataset`, tt can also cache the result of first N transforms, no matter it's random or not. + Extension of `PersistentDataset`, it can also cache the result of first N transforms, no matter it's random or not. """ @@ -1275,6 +1275,7 @@ def __len__(self) -> int: return min(len(dataset) for dataset in self.data) def _transform(self, index: int): + def to_list(x): return list(x) if isinstance(x, (tuple, list)) else [x] diff --git a/monai/data/video_dataset.py b/monai/data/video_dataset.py index be3bcf5bd5..9ff23ebeff 100644 --- a/monai/data/video_dataset.py +++ b/monai/data/video_dataset.py @@ -173,15 +173,15 @@ def get_available_codecs() -> dict[str, str]: all_codecs = {"mp4v": ".mp4", "X264": ".avi", "H264": ".mp4", "MP42": ".mp4", "MJPG": ".mjpeg", "DIVX": ".avi"} codecs = {} with SuppressStderr(): - writer = cv2.VideoWriter() with tempfile.TemporaryDirectory() as tmp_dir: for codec, ext in all_codecs.items(): + writer = cv2.VideoWriter() fname = os.path.join(tmp_dir, f"test{ext}") - fourcc = cv2.VideoWriter_fourcc(*codec) + fourcc = cv2.VideoWriter_fourcc(*codec) # type: ignore[attr-defined] noviderr = writer.open(fname, fourcc, 1, (10, 10)) if noviderr: codecs[codec] = ext - writer.release() + writer.release() return codecs def get_num_frames(self) -> int: diff --git a/monai/fl/client/client_algo.py b/monai/fl/client/client_algo.py index 25a88a9e66..3dc9f5785d 100644 --- a/monai/fl/client/client_algo.py +++ b/monai/fl/client/client_algo.py @@ -57,6 +57,7 @@ def abort(self, extra: dict | None = None) -> None: class ClientAlgoStats(BaseClient): + def get_data_stats(self, extra: dict | None = None) -> ExchangeObject: """ Get summary statistics about the local data. diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index 0382b8cb64..021154d705 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -157,6 +157,7 @@ def attach(self, engine: Engine, name: str) -> None: # type: ignore[override] @deprecated(since="1.2", removed="1.4", msg_suffix="Use IgniteMetricHandler instead of IgniteMetric.") class IgniteMetric(IgniteMetricHandler): + def __init__( self, metric_fn: CumulativeIterationMetric | None = None, diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 92898c81ca..e937b53fa4 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -12,6 +12,7 @@ from __future__ import annotations from .adversarial_loss import PatchAdversarialLoss +from .barlow_twins import BarlowTwinsLoss from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss from .contrastive import ContrastiveLoss from .deform import BendingEnergyLoss, DiffusionLoss @@ -40,5 +41,6 @@ from .spatial_mask import MaskedLoss from .spectral_loss import JukeboxLoss from .ssim_loss import SSIMLoss +from .sure_loss import SURELoss from .tversky import TverskyLoss from .unified_focal_loss import AsymmetricUnifiedFocalLoss diff --git a/monai/losses/barlow_twins.py b/monai/losses/barlow_twins.py new file mode 100644 index 0000000000..a61acca66e --- /dev/null +++ b/monai/losses/barlow_twins.py @@ -0,0 +1,84 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +from torch.nn.modules.loss import _Loss + + +class BarlowTwinsLoss(_Loss): + """ + The Barlow Twins cost function takes the representations extracted by a neural network from two + distorted views and seeks to make the cross-correlation matrix of the two representations tend + towards identity. This encourages the neural network to learn similar representations with the least + amount of redundancy. This cost function can be used in particular in multimodal learning to work on + representations from two modalities. The most common use case is for unsupervised learning, where data + augmentations are used to generate 2 distorted views of the same sample to force the encoder to + extract useful features for downstream tasks. + + Zbontar, Jure, et al. "Barlow Twins: Self-Supervised Learning via Redundancy Reduction" International + conference on machine learning. PMLR, 2020. (http://proceedings.mlr.press/v139/zbontar21a/zbontar21a.pdf) + + Adapted from: + https://github.com/facebookresearch/barlowtwins + + """ + + def __init__(self, lambd: float = 5e-3) -> None: + """ + Args: + lamb: Can be any float to handle the informativeness and invariance trade-off. Ideally set to 5e-3. + + Raises: + ValueError: When an input of dimension length > 2 is passed + ValueError: When input and target are of different shapes + ValueError: When batch size is less than or equal to 1 + + """ + super().__init__() + self.lambd = lambd + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be B[F]. + target: the shape should be B[F]. + """ + if len(target.shape) > 2 or len(input.shape) > 2: + raise ValueError( + f"Either target or input has dimensions greater than 2 where target " + f"shape is ({target.shape}) and input shape is ({input.shape})" + ) + + if target.shape != input.shape: + raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") + + if target.size(0) <= 1: + raise ValueError( + f"Batch size must be greater than 1 to compute Barlow Twins Loss, but got {target.size(0)}" + ) + + lambd_tensor = torch.as_tensor(self.lambd).to(input.device) + batch_size = input.shape[0] + + # normalize input and target + input_norm = (input - input.mean(0)) / input.std(0).add(1e-6) + target_norm = (target - target.mean(0)) / target.std(0).add(1e-6) + + # cross-correlation matrix + c = torch.mm(input_norm.t(), target_norm) / batch_size # input_norm.t() is FxB, target_norm is BxF so c is FxF + + # loss + c_diff = (c - torch.eye(c.size(0), device=c.device)).pow_(2) # FxF + c_diff[~torch.eye(c.size(0), device=c.device).bool()] *= lambd_tensor + + return c_diff.sum() diff --git a/monai/losses/dice.py b/monai/losses/dice.py index b3c0f57c6e..f1c357d31f 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -778,12 +778,22 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: ValueError: When number of dimensions for input and target are different. - ValueError: When number of channels for target is neither 1 nor the same as input. + ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input. + + Returns: + torch.Tensor: value of the loss. """ - if len(input.shape) != len(target.shape): + if input.dim() != target.dim(): raise ValueError( "the number of dimensions for input and target should be the same, " + f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). " + "if target is not one-hot encoded, please provide a tensor with shape B1H[WD]." + ) + + if target.shape[1] != 1 and target.shape[1] != input.shape[1]: + raise ValueError( + "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, " f"got shape {input.shape} and {target.shape}." ) @@ -899,14 +909,24 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: ValueError: When number of dimensions for input and target are different. - ValueError: When number of channels for target is neither 1 nor the same as input. + ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input. + Returns: + torch.Tensor: value of the loss. """ - if len(input.shape) != len(target.shape): + if input.dim() != target.dim(): raise ValueError( "the number of dimensions for input and target should be the same, " + f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). " + "if target is not one-hot encoded, please provide a tensor with shape B1H[WD]." + ) + + if target.shape[1] != 1 and target.shape[1] != input.shape[1]: + raise ValueError( + "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, " f"got shape {input.shape} and {target.shape}." ) + if self.to_onehot_y: n_pred_ch = input.shape[1] if n_pred_ch == 1: @@ -1015,15 +1035,23 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: target (torch.Tensor): the shape should be BNH[WD] or B1H[WD]. Raises: - ValueError: When the input and target tensors have different numbers of dimensions, or the target - channel isn't either one-hot encoded or categorical with the same shape of the input. + ValueError: When number of dimensions for input and target are different. + ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input. Returns: torch.Tensor: value of the loss. """ if input.dim() != target.dim(): raise ValueError( - f"Input - {input.shape} - and target - {target.shape} - must have the same number of dimensions." + "the number of dimensions for input and target should be the same, " + f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). " + "if target is not one-hot encoded, please provide a tensor with shape B1H[WD]." + ) + + if target.shape[1] != 1 and target.shape[1] != input.shape[1]: + raise ValueError( + "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, " + f"got shape {input.shape} and {target.shape}." ) gdl_loss = self.generalized_dice(input, target) diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index 98c1a071b6..28d1c0cdc9 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -234,9 +234,8 @@ def sigmoid_focal_loss( """ # computing binary cross entropy with logits # equivalent to F.binary_cross_entropy_with_logits(input, target, reduction='none') - # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231 - max_val = (-input).clamp(min=0) - loss: torch.Tensor = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log() + # see also https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L363 + loss: torch.Tensor = input - input * target - F.logsigmoid(input) # sigmoid(-i) if t==1; sigmoid(i) if t==0 <=> # 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=> diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 2207de5e64..a8ae90993a 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -29,7 +29,7 @@ class PercetualNetworkType(StrEnum): squeeze = "squeeze" radimagenet_resnet50 = "radimagenet_resnet50" medicalnet_resnet10_23datasets = "medicalnet_resnet10_23datasets" - medical_resnet50_23datasets = "medical_resnet50_23datasets" + medicalnet_resnet50_23datasets = "medicalnet_resnet50_23datasets" resnet50 = "resnet50" @@ -45,6 +45,7 @@ class PerceptualLoss(nn.Module): The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual loss on slices from all three axes and average. The full 3D approach uses a 3D network to calculate the perceptual loss. + MedicalNet networks are only compatible with 3D inputs and support channel-wise loss. Args: spatial_dims: number of spatial dimensions. @@ -62,6 +63,8 @@ class PerceptualLoss(nn.Module): pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to extract the expected state dict. This argument only works when ``"network_type"`` is "resnet50". Defaults to `None`. + channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels. + Defaults to ``False``. """ def __init__( @@ -74,6 +77,7 @@ def __init__( pretrained: bool = True, pretrained_path: str | None = None, pretrained_state_dict_key: str | None = None, + channel_wise: bool = False, ): super().__init__() @@ -86,6 +90,9 @@ def __init__( "Argument is_fake_3d must be set to False." ) + if channel_wise and "medicalnet_" not in network_type: + raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.") + if network_type.lower() not in list(PercetualNetworkType): raise ValueError( "Unrecognised criterion entered for Adversarial Loss. Must be one in: %s" @@ -102,7 +109,9 @@ def __init__( self.spatial_dims = spatial_dims self.perceptual_function: nn.Module if spatial_dims == 3 and is_fake_3d is False: - self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False) + self.perceptual_function = MedicalNetPerceptualSimilarity( + net=network_type, verbose=False, channel_wise=channel_wise + ) elif "radimagenet_" in network_type: self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) elif network_type == "resnet50": @@ -116,6 +125,7 @@ def __init__( self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False) self.is_fake_3d = is_fake_3d self.fake_3d_ratio = fake_3d_ratio + self.channel_wise = channel_wise def _calculate_axis_loss(self, input: torch.Tensor, target: torch.Tensor, spatial_axis: int) -> torch.Tensor: """ @@ -172,7 +182,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # 2D and real 3D cases loss = self.perceptual_function(input, target) - return torch.mean(loss) + if self.channel_wise: + loss = torch.mean(loss.squeeze(), dim=0) + else: + loss = torch.mean(loss) + + return loss class MedicalNetPerceptualSimilarity(nn.Module): @@ -185,14 +200,20 @@ class MedicalNetPerceptualSimilarity(nn.Module): net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``} Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``. verbose: if false, mute messages from torch Hub load function. + channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels. + Defaults to ``False``. """ - def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None: + def __init__( + self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False + ) -> None: super().__init__() torch.hub._validate_not_a_forked_repo = lambda a, b, c: True self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose) self.eval() + self.channel_wise = channel_wise + for param in self.parameters(): param.requires_grad = False @@ -206,20 +227,42 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Args: input: 3D input tensor with shape BCDHW. target: 3D target tensor with shape BCDHW. + """ input = medicalnet_intensity_normalisation(input) target = medicalnet_intensity_normalisation(target) # Get model outputs - outs_input = self.model.forward(input) - outs_target = self.model.forward(target) + feats_per_ch = 0 + for ch_idx in range(input.shape[1]): + input_channel = input[:, ch_idx, ...].unsqueeze(1) + target_channel = target[:, ch_idx, ...].unsqueeze(1) + + if ch_idx == 0: + outs_input = self.model.forward(input_channel) + outs_target = self.model.forward(target_channel) + feats_per_ch = outs_input.shape[1] + else: + outs_input = torch.cat([outs_input, self.model.forward(input_channel)], dim=1) + outs_target = torch.cat([outs_target, self.model.forward(target_channel)], dim=1) # Normalise through the channels feats_input = normalize_tensor(outs_input) feats_target = normalize_tensor(outs_target) - results: torch.Tensor = (feats_input - feats_target) ** 2 - results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True) + feats_diff: torch.Tensor = (feats_input - feats_target) ** 2 + if self.channel_wise: + results = torch.zeros( + feats_diff.shape[0], input.shape[1], feats_diff.shape[2], feats_diff.shape[3], feats_diff.shape[4] + ) + for i in range(input.shape[1]): + l_idx = i * feats_per_ch + r_idx = (i + 1) * feats_per_ch + results[:, i, ...] = feats_diff[:, l_idx : i + r_idx, ...].sum(dim=1) + else: + results = feats_diff.sum(dim=1, keepdim=True) + + results = spatial_average_3d(results, keepdim=True) return results diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py new file mode 100644 index 0000000000..ebf25613a6 --- /dev/null +++ b/monai/losses/sure_loss.py @@ -0,0 +1,200 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Callable, Optional + +import torch +import torch.nn as nn +from torch.nn.modules.loss import _Loss + + +def complex_diff_abs_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + First compute the difference in the complex domain, + then get the absolute value and take the mse + + Args: + x, y - B, 2, H, W real valued tensors representing complex numbers + or B,1,H,W complex valued tensors + Returns: + l2_loss - scalar + """ + if not x.is_complex(): + x = torch.view_as_complex(x.permute(0, 2, 3, 1).contiguous()) + if not y.is_complex(): + y = torch.view_as_complex(y.permute(0, 2, 3, 1).contiguous()) + + diff = torch.abs(x - y) + return nn.functional.mse_loss(diff, torch.zeros_like(diff), reduction="mean") + + +def sure_loss_function( + operator: Callable, + x: torch.Tensor, + y_pseudo_gt: torch.Tensor, + y_ref: Optional[torch.Tensor] = None, + eps: Optional[float] = -1.0, + perturb_noise: Optional[torch.Tensor] = None, + complex_input: Optional[bool] = False, +) -> torch.Tensor: + """ + Args: + operator (function): The operator function that takes in an input + tensor x and returns an output tensor y. We will use this to compute + the divergence. More specifically, we will perturb the input x by a + small amount and compute the divergence between the perturbed output + and the reference output + + x (torch.Tensor): The input tensor of shape (B, C, H, W) to the + operator. For complex input, the shape is (B, 2, H, W) aka C=2 real. + For real input, the shape is (B, 1, H, W) real. + + y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape + (B, C, H, W) used to compute the L2 loss. For complex input, the shape is + (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) + real. + + y_ref (torch.Tensor, optional): The reference output tensor of shape + (B, C, H, W) used to compute the divergence. Defaults to None. For + complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, + the shape is (B, 1, H, W) real. + + eps (float, optional): The perturbation scalar. Set to -1 to set it + automatically estimated based on y_pseudo_gtk + + perturb_noise (torch.Tensor, optional): The noise vector of shape (B, C, H, W). + Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real. + For real input, the shape is (B, 1, H, W) real. + + complex_input(bool, optional): Whether the input is complex or not. + Defaults to False. + + Returns: + sure_loss (torch.Tensor): The SURE loss scalar. + """ + # perturb input + if perturb_noise is None: + perturb_noise = torch.randn_like(x) + if eps == -1.0: + eps = float(torch.abs(y_pseudo_gt.max())) / 1000 + # get y_ref if not provided + if y_ref is None: + y_ref = operator(x) + + # get perturbed output + x_perturbed = x + eps * perturb_noise + y_perturbed = operator(x_perturbed) + # divergence + divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref)) # type: ignore + # l2 loss between y_ref, y_pseudo_gt + if complex_input: + l2_loss = complex_diff_abs_loss(y_ref, y_pseudo_gt) + else: + # real input + l2_loss = nn.functional.mse_loss(y_ref, y_pseudo_gt, reduction="mean") + + # sure loss + sure_loss = l2_loss * divergence / (x.shape[0] * x.shape[2] * x.shape[3]) + return sure_loss + + +class SURELoss(_Loss): + """ + Calculate the Stein's Unbiased Risk Estimator (SURE) loss for a given operator. + + This is a differentiable loss function that can be used to train/guide an + operator (e.g. neural network), where the pseudo ground truth is available + but the reference ground truth is not. For example, in the MRI + reconstruction, the pseudo ground truth is the zero-filled reconstruction + and the reference ground truth is the fully sampled reconstruction. Often, + the reference ground truth is not available due to the lack of fully sampled + data. + + The original SURE loss is proposed in [1]. The SURE loss used for guiding + the diffusion model based MRI reconstruction is proposed in [2]. + + Reference + + [1] Stein, C.M.: Estimation of the mean of a multivariate normal distribution. Annals of Statistics + + [2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with Diffusion Models. + (https://arxiv.org/pdf/2310.01799.pdf) + """ + + def __init__(self, perturb_noise: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> None: + """ + Args: + perturb_noise (torch.Tensor, optional): The noise vector of shape + (B, C, H, W). Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real. + For real input, the shape is (B, 1, H, W) real. + + eps (float, optional): The perturbation scalar. Defaults to None. + """ + super().__init__() + self.perturb_noise = perturb_noise + self.eps = eps + + def forward( + self, + operator: Callable, + x: torch.Tensor, + y_pseudo_gt: torch.Tensor, + y_ref: Optional[torch.Tensor] = None, + complex_input: Optional[bool] = False, + ) -> torch.Tensor: + """ + Args: + operator (function): The operator function that takes in an input + tensor x and returns an output tensor y. We will use this to compute + the divergence. More specifically, we will perturb the input x by a + small amount and compute the divergence between the perturbed output + and the reference output + + x (torch.Tensor): The input tensor of shape (B, C, H, W) to the + operator. C=1 or 2: For complex input, the shape is (B, 2, H, W) aka + C=2 real. For real input, the shape is (B, 1, H, W) real. + + y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape + (B, C, H, W) used to compute the L2 loss. C=1 or 2: For complex + input, the shape is (B, 2, H, W) aka C=2 real. For real input, the + shape is (B, 1, H, W) real. + + y_ref (torch.Tensor, optional): The reference output tensor of the + same shape as y_pseudo_gt + + Returns: + sure_loss (torch.Tensor): The SURE loss scalar. + """ + + # check inputs shapes + if x.dim() != 4: + raise ValueError(f"Input tensor x should be 4D, got {x.dim()}.") + if y_pseudo_gt.dim() != 4: + raise ValueError(f"Input tensor y_pseudo_gt should be 4D, but got {y_pseudo_gt.dim()}.") + if y_ref is not None and y_ref.dim() != 4: + raise ValueError(f"Input tensor y_ref should be 4D, but got {y_ref.dim()}.") + if x.shape != y_pseudo_gt.shape: + raise ValueError( + f"Input tensor x and y_pseudo_gt should have the same shape, but got x shape {x.shape}, " + f"y_pseudo_gt shape {y_pseudo_gt.shape}." + ) + if y_ref is not None and y_pseudo_gt.shape != y_ref.shape: + raise ValueError( + f"Input tensor y_pseudo_gt and y_ref should have the same shape, but got y_pseudo_gt shape {y_pseudo_gt.shape}, " + f"y_ref shape {y_ref.shape}." + ) + + # compute loss + loss = sure_loss_function(operator, x, y_pseudo_gt, y_ref, self.eps, self.perturb_noise, complex_input) + + return loss diff --git a/monai/metrics/f_beta_score.py b/monai/metrics/f_beta_score.py index 61e4525662..bb9371c8bf 100644 --- a/monai/metrics/f_beta_score.py +++ b/monai/metrics/f_beta_score.py @@ -22,6 +22,7 @@ class FBetaScore(CumulativeIterationMetric): + def __init__( self, beta: float = 1.0, diff --git a/monai/metrics/metric.py b/monai/metrics/metric.py index a6dc1a49a2..249b2dc951 100644 --- a/monai/metrics/metric.py +++ b/monai/metrics/metric.py @@ -37,6 +37,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + def __str__(self): + return self.__class__.__name__ + class IterationMetric(Metric): """ diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index 9d29654ee3..4c8b8aa71b 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -303,7 +303,7 @@ def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor if self.spatial_dims == 3 and dims != 5: raise ValueError( - f"y_pred should have 4 dimensions (batch, channel, height, width, depth) when using {self.spatial_dims}" + f"y_pred should have 5 dimensions (batch, channel, height, width, depth) when using {self.spatial_dims}" f" spatial dimensions, got {dims}." ) diff --git a/monai/networks/blocks/dynunet_block.py b/monai/networks/blocks/dynunet_block.py index 12afab3464..801b49de8b 100644 --- a/monai/networks/blocks/dynunet_block.py +++ b/monai/networks/blocks/dynunet_block.py @@ -245,6 +245,7 @@ def forward(self, inp, skip): class UnetOutBlock(nn.Module): + def __init__( self, spatial_dims: int, in_channels: int, out_channels: int, dropout: tuple | str | float | None = None ): diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index 11808eabf7..6e0efc8588 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -72,6 +72,7 @@ def get_deconv_block(spatial_dims: int, in_channels: int, out_channels: int) -> class ResidualBlock(nn.Module): + def __init__( self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Sequence[int] | int ) -> None: @@ -95,6 +96,7 @@ def forward(self, x) -> torch.Tensor: class LocalNetResidualBlock(nn.Module): + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None: super().__init__() if in_channels != out_channels: diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 7d56045814..91bd73ebbb 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -120,9 +120,7 @@ def __init__( for in_size, pa_size in zip(img_size, patch_size): grid_size.append(in_size // pa_size) - with torch.no_grad(): - pos_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims) - self.position_embeddings.data.copy_(pos_embeddings.float()) + self.position_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims) else: raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.") diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py index 138149cac6..21586e56da 100644 --- a/monai/networks/blocks/pos_embed_utils.py +++ b/monai/networks/blocks/pos_embed_utils.py @@ -23,6 +23,7 @@ # From PyTorch internals def _ntuple(n): + def parse(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): return tuple(x) @@ -45,7 +46,7 @@ def build_sincos_position_embedding( temperature (float): The temperature for the sin-cos position embedding. Returns: - pos_embed (nn.Parameter): The sin-cos position embedding as a learnable parameter. + pos_embed (nn.Parameter): The sin-cos position embedding as a fixed parameter. """ if spatial_dims == 2: diff --git a/monai/networks/blocks/upsample.py b/monai/networks/blocks/upsample.py index dee9966919..50fd39a70b 100644 --- a/monai/networks/blocks/upsample.py +++ b/monai/networks/blocks/upsample.py @@ -17,8 +17,8 @@ import torch.nn as nn from monai.networks.layers.factories import Conv, Pad, Pool -from monai.networks.utils import icnr_init, pixelshuffle -from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option +from monai.networks.utils import CastTempType, icnr_init, pixelshuffle +from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option, pytorch_after __all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"] @@ -50,6 +50,7 @@ def __init__( size: tuple[int] | int | None = None, mode: UpsampleMode | str = UpsampleMode.DECONV, pre_conv: nn.Module | str | None = "default", + post_conv: nn.Module | None = None, interp_mode: str = InterpolateMode.LINEAR, align_corners: bool | None = True, bias: bool = True, @@ -71,6 +72,7 @@ def __init__( pre_conv: a conv block applied before upsampling. Defaults to "default". When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized when Only used in the "nontrainable" or "pixelshuffle" mode. + post_conv: a conv block applied after upsampling. Defaults to None. Only used in the "nontrainable" mode. interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} Only used in the "nontrainable" mode. If ends with ``"linear"`` will use ``spatial dims`` to determine the correct interpolation. @@ -154,15 +156,25 @@ def __init__( linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR] if interp_mode in linear_mode: # choose mode based on dimensions interp_mode = linear_mode[spatial_dims - 1] - self.add_module( - "upsample_non_trainable", - nn.Upsample( - size=size, - scale_factor=None if size else scale_factor_, - mode=interp_mode.value, - align_corners=align_corners, - ), + + upsample = nn.Upsample( + size=size, + scale_factor=None if size else scale_factor_, + mode=interp_mode.value, + align_corners=align_corners, ) + + # Cast to float32 as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679. This issue is solved in PyTorch 2.1 + if pytorch_after(major=2, minor=1): + self.add_module("upsample_non_trainable", upsample) + else: + self.add_module( + "upsample_non_trainable", + CastTempType(initial_type=torch.bfloat16, temporary_type=torch.float32, submodule=upsample), + ) + if post_conv: + self.add_module("postconv", post_conv) elif up_mode == UpsampleMode.PIXELSHUFFLE: self.add_module( "pixelshuffle", diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index bd3e3af3af..48c10270b1 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -11,9 +11,10 @@ from __future__ import annotations +from .conjugate_gradient import ConjugateGradient from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding from .drop_path import DropPath -from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args +from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, RelPosEmbedding, split_args from .filtering import BilateralFilter, PHLFilter, TrainableBilateralFilter, TrainableJointBilateralFilter from .gmm import GaussianMixtureModel from .simplelayers import ( diff --git a/monai/networks/layers/conjugate_gradient.py b/monai/networks/layers/conjugate_gradient.py new file mode 100644 index 0000000000..93a45930d7 --- /dev/null +++ b/monai/networks/layers/conjugate_gradient.py @@ -0,0 +1,112 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Callable + +import torch +from torch import nn + + +def _zdot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """ + Complex dot product between tensors x1 and x2: sum(x1.*x2) + """ + if torch.is_complex(x1): + assert torch.is_complex(x2), "x1 and x2 must both be complex" + return torch.sum(x1.conj() * x2) + else: + return torch.sum(x1 * x2) + + +def _zdot_single(x: torch.Tensor) -> torch.Tensor: + """ + Complex dot product between tensor x and itself + """ + res = _zdot(x, x) + if torch.is_complex(res): + return res.real + else: + return res + + +class ConjugateGradient(nn.Module): + """ + Congugate Gradient (CG) solver for linear systems Ax = y. + + For linear_op that is positive definite and self-adjoint, CG is + guaranteed to converge CG is often used to solve linear systems of the form + Ax = y, where A is too large to store explicitly, but can be computed via a + linear operator. + + As a result, here we won't set A explicitly as a matrix, but rather as a + linear operator. For example, A could be a FFT/IFFT operation + """ + + def __init__(self, linear_op: Callable, num_iter: int): + """ + Args: + linear_op: Linear operator + num_iter: Number of iterations to run CG + """ + super().__init__() + + self.linear_op = linear_op + self.num_iter = num_iter + + def update( + self, x: torch.Tensor, p: torch.Tensor, r: torch.Tensor, rsold: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + perform one iteration of the CG method. It takes the current solution x, + the current search direction p, the current residual r, and the old + residual norm rsold as inputs. Then it computes the new solution, search + direction, residual, and residual norm, and returns them. + """ + + dy = self.linear_op(p) + p_dot_dy = _zdot(p, dy) + alpha = rsold / p_dot_dy + x = x + alpha * p + r = r - alpha * dy + rsnew = _zdot_single(r) + beta = rsnew / rsold + rsold = rsnew + p = beta * p + r + return x, p, r, rsold + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + run conjugate gradient for num_iter iterations to solve Ax = y + + Args: + x: tensor (real or complex); Initial guess for linear system Ax = y. + The size of x should be applicable to the linear operator. For + example, if the linear operator is FFT, then x is HCHW; if the + linear operator is a matrix multiplication, then x is a vector + + y: tensor (real or complex); Measurement. Same size as x + + Returns: + x: Solution to Ax = y + """ + # Compute residual + r = y - self.linear_op(x) + rsold = _zdot_single(r) + p = r + + # Update + for _i in range(self.num_iter): + x, p, r, rsold = self.update(x, p, r, rsold) + if rsold < 1e-10: + break + return x diff --git a/monai/networks/layers/gmm.py b/monai/networks/layers/gmm.py index 94d619bb7a..6ebe66832f 100644 --- a/monai/networks/layers/gmm.py +++ b/monai/networks/layers/gmm.py @@ -78,6 +78,7 @@ def apply(self, features): class _ApplyFunc(torch.autograd.Function): + @staticmethod def forward(ctx, params, features, compiled_extension): return compiled_extension.apply(params, features) diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index a1122ceaa2..4ac621967f 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -552,6 +552,7 @@ def forward(self, in_tensor: torch.Tensor, number_of_passes=1) -> torch.Tensor: class GaussianFilter(nn.Module): + def __init__( self, spatial_dims: int, @@ -607,6 +608,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class LLTMFunction(Function): + @staticmethod def forward(ctx, input, weights, bias, old_h, old_cell): outputs = _C.lltm_forward(input, weights, bias, old_h, old_cell) diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index 53f35e63f2..2d39dfdbc1 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -33,6 +33,7 @@ class _GridPull(torch.autograd.Function): + @staticmethod def forward(ctx, input, grid, interpolation, bound, extrapolate): opt = (bound, interpolation, extrapolate) @@ -132,6 +133,7 @@ def grid_pull( class _GridPush(torch.autograd.Function): + @staticmethod def forward(ctx, input, grid, shape, interpolation, bound, extrapolate): opt = (bound, interpolation, extrapolate) @@ -236,6 +238,7 @@ def grid_push( class _GridCount(torch.autograd.Function): + @staticmethod def forward(ctx, grid, shape, interpolation, bound, extrapolate): opt = (bound, interpolation, extrapolate) @@ -335,6 +338,7 @@ def grid_count(grid: torch.Tensor, shape=None, interpolation="linear", bound="ze class _GridGrad(torch.autograd.Function): + @staticmethod def forward(ctx, input, grid, interpolation, bound, extrapolate): opt = (bound, interpolation, extrapolate) @@ -433,6 +437,7 @@ def grid_grad(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", b class AffineTransform(nn.Module): + def __init__( self, spatial_size: Sequence[int] | int | None = None, diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index a7ce16ad64..9101ab862e 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -63,6 +63,8 @@ ResNet, ResNetBlock, ResNetBottleneck, + ResNetEncoder, + ResNetFeatures, get_medicalnet_pretrained_resnet_args, get_pretrained_resnet_medicalnet, resnet10, diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py index b0ad1eabbd..5e280d7f24 100644 --- a/monai/networks/nets/ahnet.py +++ b/monai/networks/nets/ahnet.py @@ -87,6 +87,7 @@ def forward(self, x): class Projection(nn.Sequential): + def __init__(self, spatial_dims: int, num_input_features: int, num_output_features: int): super().__init__() @@ -100,6 +101,7 @@ def __init__(self, spatial_dims: int, num_input_features: int, num_output_featur class DenseBlock(nn.Sequential): + def __init__( self, spatial_dims: int, @@ -118,6 +120,7 @@ def __init__( class UpTransition(nn.Sequential): + def __init__( self, spatial_dims: int, num_input_features: int, num_output_features: int, upsample_mode: str = "transpose" ): @@ -143,6 +146,7 @@ def __init__( class Final(nn.Sequential): + def __init__( self, spatial_dims: int, num_input_features: int, num_output_features: int, upsample_mode: str = "transpose" ): @@ -178,6 +182,7 @@ def __init__( class Pseudo3DLayer(nn.Module): + def __init__(self, spatial_dims: int, num_input_features: int, growth_rate: int, bn_size: int, dropout_prob: float): super().__init__() # 1x1x1 @@ -244,6 +249,7 @@ def forward(self, x): class PSP(nn.Module): + def __init__(self, spatial_dims: int, psp_block_num: int, in_ch: int, upsample_mode: str = "transpose"): super().__init__() self.up_modules = nn.ModuleList() diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py index 362d63d636..5689cf1071 100644 --- a/monai/networks/nets/attentionunet.py +++ b/monai/networks/nets/attentionunet.py @@ -23,6 +23,7 @@ class ConvBlock(nn.Module): + def __init__( self, spatial_dims: int, @@ -67,6 +68,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class UpConv(nn.Module): + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size=3, strides=2, dropout=0.0): super().__init__() self.up = Convolution( @@ -88,6 +90,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AttentionBlock(nn.Module): + def __init__(self, spatial_dims: int, f_int: int, f_g: int, f_l: int, dropout=0.0): super().__init__() self.W_g = nn.Sequential( @@ -145,6 +148,7 @@ def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor: class AttentionLayer(nn.Module): + def __init__( self, spatial_dims: int, diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index f7ae77f056..372e704d53 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -11,7 +11,6 @@ from __future__ import annotations -import math from collections.abc import Sequence from typing import List @@ -19,92 +18,44 @@ import torch.nn as nn import torch.nn.functional as F -from monai.networks.blocks import Convolution - -# To install xformers, use pip install xformers==0.0.16rc401 +from monai.networks.blocks import Convolution, Upsample +from monai.networks.blocks.selfattention import SABlock from monai.utils import ensure_tuple_rep, optional_import -xformers, has_xformers = optional_import("xformers") +Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") __all__ = ["AutoencoderKL"] -class _Upsample(nn.Module): +class AsymmetricPad(nn.Module): """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - Convolution-based upsampling layer. + Pad the input tensor asymmetrically along every spatial dimension. Args: spatial_dims: number of spatial dimensions, could be 1, 2, or 3. - in_channels: number of input channels to the layer. - use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. """ - def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) -> None: + def __init__(self, spatial_dims: int) -> None: super().__init__() - if use_convtranspose: - self.conv = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=2, - kernel_size=3, - padding=1, - conv_only=True, - is_transposed=True, - ) - else: - self.conv = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - self.use_convtranspose = use_convtranspose + self.pad = (0, 1) * spatial_dims def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.use_convtranspose: - conv: torch.Tensor = self.conv(x) - return conv - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # https://github.com/pytorch/pytorch/issues/86679 - dtype = x.dtype - if dtype == torch.bfloat16: - x = x.to(torch.float32) - - x = F.interpolate(x, scale_factor=2.0, mode="nearest") - - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: - x = x.to(dtype) - - x = self.conv(x) + x = nn.functional.pad(x, self.pad, mode="constant", value=0.0) return x -class _Downsample(nn.Module): +class AEKLDownsample(nn.Module): """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - Convolution-based downsampling layer. Args: - spatial_dims: number of spatial dimensions, could be 1, 2, or 3. + spatial_dims: number of spatial dimensions (1D, 2D, 3D). in_channels: number of input channels. """ def __init__(self, spatial_dims: int, in_channels: int) -> None: super().__init__() - self.pad = (0, 1) * spatial_dims + self.pad = AsymmetricPad(spatial_dims=spatial_dims) self.conv = Convolution( spatial_dims=spatial_dims, @@ -117,17 +68,13 @@ def __init__(self, spatial_dims: int, in_channels: int) -> None: ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = nn.functional.pad(x, self.pad, mode="constant", value=0.0) + x = self.pad(x) x = self.conv(x) return x -class _ResBlock(nn.Module): +class AEKLResBlock(nn.Module): """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a residual connection between input and output. @@ -197,22 +144,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x + h -class _AttentionBlock(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 +class AttentionBlock(nn.Module): + """Perform spatial self-attention on the input tensor. - Attention block. + The input tensor is reshaped to B x (x_dim * y_dim [ * z_dim]) x C, where C is the number of channels. Args: spatial_dims: number of spatial dimensions, could be 1, 2, or 3. - num_channels: number of input channels. - num_head_channels: number of channels in each attention head. - norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of - channels is divisible by this number. - norm_eps: epsilon value to use for the normalisation. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + num_channels: number of input channels. Must be divisible by num_head_channels. + num_head_channels: number of channels per head. """ def __init__( @@ -222,102 +162,41 @@ def __init__( num_head_channels: int | None = None, norm_num_groups: int = 32, norm_eps: float = 1e-6, - use_flash_attention: bool = False, ) -> None: super().__init__() - self.use_flash_attention = use_flash_attention - self.spatial_dims = spatial_dims - self.num_channels = num_channels - - self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 - self.scale = 1 / math.sqrt(num_channels / self.num_heads) + self.spatial_dims = spatial_dims self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + # check num_head_channels is divisible by num_channels + if num_head_channels is not None and num_channels % num_head_channels != 0: + raise ValueError("num_channels must be divisible by num_head_channels") + num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 - self.to_q = nn.Linear(num_channels, num_channels) - self.to_k = nn.Linear(num_channels, num_channels) - self.to_v = nn.Linear(num_channels, num_channels) - - self.proj_attn = nn.Linear(num_channels, num_channels) - - def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: - """ - Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. - """ - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) - x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) - return x - - def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: - """Combine the output of the attention heads back into the hidden state dimension.""" - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) - return x - - def _memory_efficient_attention_xformers( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - x: torch.Tensor = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) - return x + self.attn = SABlock(hidden_size=num_channels, num_heads=num_heads, qkv_bias=True) - def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - x = torch.bmm(attention_probs, value) - return x - - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor): residual = x - batch = channel = height = width = depth = -1 + if self.spatial_dims == 1: + h = x.shape[2] + rearrange_input = Rearrange("b c h -> b h c") + rearrange_output = Rearrange("b h c -> b c h", h=h) if self.spatial_dims == 2: - batch, channel, height, width = x.shape + h, w = x.shape[2], x.shape[3] + rearrange_input = Rearrange("b c h w -> b (h w) c") + rearrange_output = Rearrange("b (h w) c -> b c h w", h=h, w=w) if self.spatial_dims == 3: - batch, channel, height, width, depth = x.shape + h, w, d = x.shape[2], x.shape[3], x.shape[4] + rearrange_input = Rearrange("b c h w d -> b (h w d) c") + rearrange_output = Rearrange("b (h w d) c -> b c h w d", h=h, w=w, d=d) - # norm x = self.norm(x) + x = rearrange_input(x) # B x (x_dim * y_dim [ * z_dim]) x C - if self.spatial_dims == 2: - x = x.view(batch, channel, height * width).transpose(1, 2) - if self.spatial_dims == 3: - x = x.view(batch, channel, height * width * depth).transpose(1, 2) - - # proj to q, k, v - query = self.to_q(x) - key = self.to_k(x) - value = self.to_v(x) - - # Multi-Head Attention - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - if self.use_flash_attention: - x = self._memory_efficient_attention_xformers(query, key, value) - else: - x = self._attention(query, key, value) - - x = self.reshape_batch_dim_to_heads(x) - x = x.to(query.dtype) - - if self.spatial_dims == 2: - x = x.transpose(-1, -2).reshape(batch, channel, height, width) - if self.spatial_dims == 3: - x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) - - return x + residual + x = self.attn(x) + x = rearrange_output(x) # B x x C x x_dim * y_dim * [z_dim] + x = x + residual + return x class Encoder(nn.Module): @@ -334,7 +213,6 @@ class Encoder(nn.Module): norm_eps: epsilon for the normalization. attention_levels: indicate which level from num_channels contain an attention block. with_nonlocal_attn: if True use non-local attention block. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -348,7 +226,6 @@ def __init__( norm_eps: float, attention_levels: Sequence[bool], with_nonlocal_attn: bool = True, - use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -383,7 +260,7 @@ def __init__( for _ in range(self.num_res_blocks[i]): blocks.append( - _ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=input_channel, norm_num_groups=norm_num_groups, @@ -394,22 +271,20 @@ def __init__( input_channel = output_channel if attention_levels[i]: blocks.append( - _AttentionBlock( + AttentionBlock( spatial_dims=spatial_dims, num_channels=input_channel, norm_num_groups=norm_num_groups, norm_eps=norm_eps, - use_flash_attention=use_flash_attention, ) ) if not is_final_block: - blocks.append(_Downsample(spatial_dims=spatial_dims, in_channels=input_channel)) - + blocks.append(AEKLDownsample(spatial_dims=spatial_dims, in_channels=input_channel)) # Non-local attention block if with_nonlocal_attn is True: blocks.append( - _ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=channels[-1], norm_num_groups=norm_num_groups, @@ -419,16 +294,15 @@ def __init__( ) blocks.append( - _AttentionBlock( + AttentionBlock( spatial_dims=spatial_dims, num_channels=channels[-1], norm_num_groups=norm_num_groups, norm_eps=norm_eps, - use_flash_attention=use_flash_attention, ) ) blocks.append( - _ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=channels[-1], norm_num_groups=norm_num_groups, @@ -472,7 +346,6 @@ class Decoder(nn.Module): norm_eps: epsilon for the normalization. attention_levels: indicate which level from num_channels contain an attention block. with_nonlocal_attn: if True use non-local attention block. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. """ @@ -487,7 +360,6 @@ def __init__( norm_eps: float, attention_levels: Sequence[bool], with_nonlocal_attn: bool = True, - use_flash_attention: bool = False, use_convtranspose: bool = False, ) -> None: super().__init__() @@ -520,7 +392,7 @@ def __init__( # Non-local attention block if with_nonlocal_attn is True: blocks.append( - _ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, @@ -529,16 +401,15 @@ def __init__( ) ) blocks.append( - _AttentionBlock( + AttentionBlock( spatial_dims=spatial_dims, num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, norm_eps=norm_eps, - use_flash_attention=use_flash_attention, ) ) blocks.append( - _ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, @@ -557,7 +428,7 @@ def __init__( for _ in range(reversed_num_res_blocks[i]): blocks.append( - _ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=block_in_ch, norm_num_groups=norm_num_groups, @@ -569,19 +440,43 @@ def __init__( if reversed_attention_levels[i]: blocks.append( - _AttentionBlock( + AttentionBlock( spatial_dims=spatial_dims, num_channels=block_in_ch, norm_num_groups=norm_num_groups, norm_eps=norm_eps, - use_flash_attention=use_flash_attention, ) ) if not is_final_block: - blocks.append( - _Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=use_convtranspose) - ) + if use_convtranspose: + blocks.append( + Upsample( + spatial_dims=spatial_dims, mode="deconv", in_channels=block_in_ch, out_channels=block_in_ch + ) + ) + else: + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=block_in_ch, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + blocks.append( + Upsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=block_in_ch, + out_channels=block_in_ch, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, + ) + ) blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) blocks.append( @@ -622,7 +517,6 @@ class AutoencoderKL(nn.Module): norm_eps: epsilon for the normalization. with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. use_checkpoint: if True, use activation checkpoint to save memory. use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. """ @@ -640,7 +534,6 @@ def __init__( norm_eps: float = 1e-6, with_encoder_nonlocal_attn: bool = True, with_decoder_nonlocal_attn: bool = True, - use_flash_attention: bool = False, use_checkpoint: bool = False, use_convtranspose: bool = False, ) -> None: @@ -662,11 +555,6 @@ def __init__( "`num_channels`." ) - if use_flash_attention is True and not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." - ) - self.encoder = Encoder( spatial_dims=spatial_dims, in_channels=in_channels, @@ -677,7 +565,6 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_encoder_nonlocal_attn, - use_flash_attention=use_flash_attention, ) self.decoder = Decoder( spatial_dims=spatial_dims, @@ -689,7 +576,6 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_decoder_nonlocal_attn, - use_flash_attention=use_flash_attention, use_convtranspose=use_convtranspose, ) self.quant_conv_mu = Convolution( @@ -805,3 +691,68 @@ def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: image = self.decode(z) return image + + def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: + """ + Load a state dict from an AutoencoderKL trained with [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels). + + Args: + old_state_dict: state dict from the old AutoencoderKL model. + """ + + new_state_dict = self.state_dict() + # if all keys match, just load the state dict + if all(k in new_state_dict for k in old_state_dict): + print("All keys match, loading state dict.") + self.load_state_dict(old_state_dict) + return + + if verbose: + # print all new_state_dict keys that are not in old_state_dict + for k in new_state_dict: + if k not in old_state_dict: + print(f"key {k} not found in old state dict") + # and vice versa + print("----------------------------------------------") + for k in old_state_dict: + if k not in new_state_dict: + print(f"key {k} not found in new state dict") + + # copy over all matching keys + for k in new_state_dict: + if k in old_state_dict: + new_state_dict[k] = old_state_dict[k] + + # fix the attention blocks + attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k] + for block in attention_blocks: + new_state_dict[f"{block}.attn.qkv.weight"] = torch.concat( + [ + old_state_dict[f"{block}.to_q.weight"], + old_state_dict[f"{block}.to_k.weight"], + old_state_dict[f"{block}.to_v.weight"], + ], + dim=0, + ) + new_state_dict[f"{block}.attn.qkv.bias"] = torch.concat( + [ + old_state_dict[f"{block}.to_q.bias"], + old_state_dict[f"{block}.to_k.bias"], + old_state_dict[f"{block}.to_v.bias"], + ], + dim=0, + ) + # old version did not have a projection so set these to the identity + new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye( + new_state_dict[f"{block}.attn.out_proj.weight"].shape[0] + ) + new_state_dict[f"{block}.attn.out_proj.bias"] = torch.zeros( + new_state_dict[f"{block}.attn.out_proj.bias"].shape + ) + + # fix the upsample conv blocks which were renamed postconv + for k in new_state_dict: + if "postconv" in k: + old_name = k.replace("postconv", "conv") + new_state_dict[k] = old_state_dict[old_name] + self.load_state_dict(new_state_dict) diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index 7fc57edc42..b9970d4113 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -176,6 +176,7 @@ def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]): class BasicUNet(nn.Module): + def __init__( self, spatial_dims: int = 3, diff --git a/monai/networks/nets/basic_unetplusplus.py b/monai/networks/nets/basic_unetplusplus.py index 28d4b4668a..f7ae768513 100644 --- a/monai/networks/nets/basic_unetplusplus.py +++ b/monai/networks/nets/basic_unetplusplus.py @@ -24,6 +24,7 @@ class BasicUNetPlusPlus(nn.Module): + def __init__( self, spatial_dims: int = 3, diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py index 2100272d91..5ccb429c91 100644 --- a/monai/networks/nets/densenet.py +++ b/monai/networks/nets/densenet.py @@ -42,6 +42,7 @@ class _DenseLayer(nn.Module): + def __init__( self, spatial_dims: int, @@ -88,6 +89,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class _DenseBlock(nn.Sequential): + def __init__( self, spatial_dims: int, @@ -119,6 +121,7 @@ def __init__( class _Transition(nn.Sequential): + def __init__( self, spatial_dims: int, diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py index 6e3420d136..129e0925d3 100644 --- a/monai/networks/nets/dints.py +++ b/monai/networks/nets/dints.py @@ -73,6 +73,7 @@ def _dfs(node, paths): class _IdentityWithRAMCost(nn.Identity): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.ram_cost = 0 @@ -105,6 +106,7 @@ def __init__( class _P3DActiConvNormBlockWithRAMCost(P3DActiConvNormBlock): + def __init__( self, in_channel: int, @@ -122,6 +124,7 @@ def __init__( class _FactorizedIncreaseBlockWithRAMCost(FactorizedIncreaseBlock): + def __init__( self, in_channel: int, @@ -138,6 +141,7 @@ def __init__( class _FactorizedReduceBlockWithRAMCost(FactorizedReduceBlock): + def __init__( self, in_channel: int, diff --git a/monai/networks/nets/efficientnet.py b/monai/networks/nets/efficientnet.py index d89ab53ea2..4e6c327b23 100644 --- a/monai/networks/nets/efficientnet.py +++ b/monai/networks/nets/efficientnet.py @@ -73,6 +73,7 @@ class MBConvBlock(nn.Module): + def __init__( self, spatial_dims: int, @@ -227,6 +228,7 @@ def set_swish(self, memory_efficient: bool = True) -> None: class EfficientNet(nn.Module): + def __init__( self, blocks_args_str: list[str], @@ -472,6 +474,7 @@ def _initialize_weights(self) -> None: class EfficientNetBN(EfficientNet): + def __init__( self, model_name: str, @@ -558,6 +561,7 @@ def __init__( class EfficientNetBNFeatures(EfficientNet): + def __init__( self, model_name: str, diff --git a/monai/networks/nets/flexible_unet.py b/monai/networks/nets/flexible_unet.py index ac2124b5f9..c27b0fc17b 100644 --- a/monai/networks/nets/flexible_unet.py +++ b/monai/networks/nets/flexible_unet.py @@ -24,6 +24,7 @@ from monai.networks.layers.utils import get_act_layer from monai.networks.nets import EfficientNetEncoder from monai.networks.nets.basic_unet import UpCat +from monai.networks.nets.resnet import ResNetEncoder from monai.utils import InterpolateMode, optional_import __all__ = ["FlexibleUNet", "FlexUNet", "FLEXUNET_BACKBONE", "FlexUNetEncoderRegister"] @@ -78,6 +79,7 @@ def register_class(self, name: type[Any] | str): FLEXUNET_BACKBONE = FlexUNetEncoderRegister() FLEXUNET_BACKBONE.register_class(EfficientNetEncoder) +FLEXUNET_BACKBONE.register_class(ResNetEncoder) class UNetDecoder(nn.Module): @@ -238,7 +240,7 @@ def __init__( ) -> None: """ A flexible implement of UNet, in which the backbone/encoder can be replaced with - any efficient network. Currently the input must have a 2 or 3 spatial dimension + any efficient or residual network. Currently the input must have a 2 or 3 spatial dimension and the spatial size of each dimension must be a multiple of 32 if is_pad parameter is False. Please notice each output of backbone must be 2x downsample in spatial dimension @@ -248,10 +250,11 @@ def __init__( Args: in_channels: number of input channels. out_channels: number of output channels. - backbone: name of backbones to initialize, only support efficientnet right now, - can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2]. - pretrained: whether to initialize pretrained ImageNet weights, only available - for spatial_dims=2 and batch norm is used, default to False. + backbone: name of backbones to initialize, only support efficientnet and resnet right now, + can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2, resnet10, ..., resnet200]. + pretrained: whether to initialize pretrained weights. ImageNet weights are available for efficient networks + if spatial_dims=2 and batch norm is used. MedicalNet weights are available for residual networks + if spatial_dims=3 and in_channels=1. Default to False. decoder_channels: number of output channels for all feature maps in decoder. `len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default to (256, 128, 64, 32, 16). diff --git a/monai/networks/nets/highresnet.py b/monai/networks/nets/highresnet.py index e71f8d193d..4959f0713f 100644 --- a/monai/networks/nets/highresnet.py +++ b/monai/networks/nets/highresnet.py @@ -36,6 +36,7 @@ class HighResBlock(nn.Module): + def __init__( self, spatial_dims: int, diff --git a/monai/networks/nets/hovernet.py b/monai/networks/nets/hovernet.py index 3ec1cea37e..5f340c9be6 100644 --- a/monai/networks/nets/hovernet.py +++ b/monai/networks/nets/hovernet.py @@ -49,6 +49,7 @@ class _DenseLayerDecoder(nn.Module): + def __init__( self, num_features: int, @@ -103,6 +104,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class _DecoderBlock(nn.Sequential): + def __init__( self, layers: int, @@ -159,6 +161,7 @@ def __init__( class _DenseLayer(nn.Sequential): + def __init__( self, num_features: int, @@ -219,6 +222,7 @@ def __init__( class _Transition(nn.Sequential): + def __init__( self, in_channels: int, act: str | tuple = ("relu", {"inplace": True}), norm: str | tuple = "batch" ) -> None: @@ -235,6 +239,7 @@ def __init__( class _ResidualBlock(nn.Module): + def __init__( self, layers: int, @@ -312,6 +317,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class _DecoderBranch(nn.ModuleList): + def __init__( self, decode_config: Sequence[int] = (8, 4), diff --git a/monai/networks/nets/milmodel.py b/monai/networks/nets/milmodel.py index 0a25b7feec..ad6b77bf3d 100644 --- a/monai/networks/nets/milmodel.py +++ b/monai/networks/nets/milmodel.py @@ -83,6 +83,7 @@ def __init__( if mil_mode == "att_trans_pyramid": # register hooks to capture outputs of intermediate layers def forward_hook(layer_name): + def hook(module, input, output): self.extra_outputs[layer_name] = output diff --git a/monai/networks/nets/regunet.py b/monai/networks/nets/regunet.py index a7c5158240..4d6150ea1b 100644 --- a/monai/networks/nets/regunet.py +++ b/monai/networks/nets/regunet.py @@ -234,6 +234,7 @@ def forward(self, x): class AffineHead(nn.Module): + def __init__( self, spatial_dims: int, @@ -375,6 +376,7 @@ def build_output_block(self): class AdditiveUpSampleBlock(nn.Module): + def __init__( self, spatial_dims: int, diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 34a4b7057e..99975271da 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -21,6 +21,7 @@ import torch import torch.nn as nn +from monai.networks.blocks.encoder import BaseEncoder from monai.networks.layers.factories import Conv, Norm, Pool from monai.networks.layers.utils import get_pool_layer from monai.utils import ensure_tuple_rep @@ -45,6 +46,19 @@ "resnet200", ] + +resnet_params = { + # model_name: (block, layers, shortcut_type, bias_downsample, datasets23) + "resnet10": ("basic", [1, 1, 1, 1], "B", False, True), + "resnet18": ("basic", [2, 2, 2, 2], "A", True, True), + "resnet34": ("basic", [3, 4, 6, 3], "A", True, True), + "resnet50": ("bottleneck", [3, 4, 6, 3], "B", False, True), + "resnet101": ("bottleneck", [3, 4, 23, 3], "B", False, False), + "resnet152": ("bottleneck", [3, 8, 36, 3], "B", False, False), + "resnet200": ("bottleneck", [3, 24, 36, 3], "B", False, False), +} + + logger = logging.getLogger(__name__) @@ -335,6 +349,120 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class ResNetFeatures(ResNet): + + def __init__(self, model_name: str, pretrained: bool = True, spatial_dims: int = 3, in_channels: int = 1) -> None: + """Initialize resnet18 to resnet200 models as a backbone, the backbone can be used as an encoder for + segmentation and objection models. + + Compared with the class `ResNet`, the only different place is the forward function. + + Args: + model_name: name of model to initialize, can be from [resnet10, ..., resnet200]. + pretrained: whether to initialize pretrained MedicalNet weights, + only available for spatial_dims=3 and in_channels=1. + spatial_dims: number of spatial dimensions of the input image. + in_channels: number of input channels for first convolutional layer. + """ + if model_name not in resnet_params: + model_name_string = ", ".join(resnet_params.keys()) + raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ") + + block, layers, shortcut_type, bias_downsample, datasets23 = resnet_params[model_name] + + super().__init__( + block=block, + layers=layers, + block_inplanes=get_inplanes(), + spatial_dims=spatial_dims, + n_input_channels=in_channels, + conv1_t_stride=2, + shortcut_type=shortcut_type, + feed_forward=False, + bias_downsample=bias_downsample, + ) + if pretrained: + if spatial_dims == 3 and in_channels == 1: + _load_state_dict(self, model_name, datasets23=datasets23) + else: + raise ValueError("Pretrained resnet models are only available for in_channels=1 and spatial_dims=3.") + + def forward(self, inputs: torch.Tensor): + """ + Args: + inputs: input should have spatially N dimensions + ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`. + + Returns: + a list of torch Tensors. + """ + x = self.conv1(inputs) + x = self.bn1(x) + x = self.relu(x) + + features = [] + features.append(x) + + if not self.no_max_pool: + x = self.maxpool(x) + + x = self.layer1(x) + features.append(x) + + x = self.layer2(x) + features.append(x) + + x = self.layer3(x) + features.append(x) + + x = self.layer4(x) + features.append(x) + + return features + + +class ResNetEncoder(ResNetFeatures, BaseEncoder): + """Wrap the original resnet to an encoder for flexible-unet.""" + + backbone_names = ["resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"] + + @classmethod + def get_encoder_parameters(cls) -> list[dict]: + """Get the initialization parameter for resnet backbones.""" + parameter_list = [] + for backbone_name in cls.backbone_names: + parameter_list.append( + {"model_name": backbone_name, "pretrained": True, "spatial_dims": 3, "in_channels": 1} + ) + return parameter_list + + @classmethod + def num_channels_per_output(cls) -> list[tuple[int, ...]]: + """Get number of resnet backbone output feature maps channel.""" + return [ + (64, 64, 128, 256, 512), + (64, 64, 128, 256, 512), + (64, 64, 128, 256, 512), + (64, 256, 512, 1024, 2048), + (64, 256, 512, 1024, 2048), + (64, 256, 512, 1024, 2048), + (64, 256, 512, 1024, 2048), + ] + + @classmethod + def num_outputs(cls) -> list[int]: + """Get number of resnet backbone output feature maps. + + Since every backbone contains the same 5 output feature maps, the number list should be `[5] * 7`. + """ + return [5] * 7 + + @classmethod + def get_encoder_names(cls) -> list[str]: + """Get names of resnet backbones.""" + return cls.backbone_names + + def _resnet( arch: str, block: type[ResNetBlock | ResNetBottleneck], @@ -477,7 +605,7 @@ def resnet200(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True): """ - Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet + Download resnet pretrained weights from https://huggingface.co/TencentMedicalNet Args: resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200 @@ -533,7 +661,7 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", dat def get_medicalnet_pretrained_resnet_args(resnet_depth: int): """ Return correct shortcut_type and bias_downsample - for pretrained MedicalNet weights according to resnet depth + for pretrained MedicalNet weights according to resnet depth. """ # After testing # False: 10, 50, 101, 152, 200 @@ -541,3 +669,16 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int): bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34 shortcut_type = "A" if resnet_depth in [18, 34] else "B" return bias_downsample, shortcut_type + + +def _load_state_dict(model: nn.Module, model_name: str, datasets23: bool = True) -> None: + search_res = re.search(r"resnet(\d+)", model_name) + if search_res: + resnet_depth = int(search_res.group(1)) + datasets23 = model_name.endswith("_23datasets") + else: + raise ValueError("model_name argument should contain resnet depth. Example: resnet18 or resnet18_23datasets.") + + model_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device="cpu", datasets23=datasets23) + model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()} + model.load_state_dict(model_state_dict) diff --git a/monai/networks/nets/spade_autoencoderkl.py b/monai/networks/nets/spade_autoencoderkl.py index e064c19740..0949e307b9 100644 --- a/monai/networks/nets/spade_autoencoderkl.py +++ b/monai/networks/nets/spade_autoencoderkl.py @@ -17,9 +17,9 @@ import torch.nn as nn import torch.nn.functional as F -from monai.networks.blocks import Convolution +from monai.networks.blocks import Convolution, Upsample from monai.networks.blocks.spade_norm import SPADE -from monai.networks.nets.autoencoderkl import Encoder, _AttentionBlock, _Upsample +from monai.networks.nets.autoencoderkl import AttentionBlock, Encoder from monai.utils import ensure_tuple_rep __all__ = ["SPADEAutoencoderKL"] @@ -136,7 +136,6 @@ class SPADEDecoder(nn.Module): attention_levels: indicate which level from channels contain an attention block. label_nc: number of semantic channels for SPADE normalisation. with_nonlocal_attn: if True use non-local attention block. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. spade_intermediate_channels: number of intermediate channels for SPADE block layer. """ @@ -152,7 +151,6 @@ def __init__( attention_levels: Sequence[bool], label_nc: int, with_nonlocal_attn: bool = True, - use_flash_attention: bool = False, spade_intermediate_channels: int = 128, ) -> None: super().__init__() @@ -197,12 +195,11 @@ def __init__( ) ) blocks.append( - _AttentionBlock( + AttentionBlock( spatial_dims=spatial_dims, num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, norm_eps=norm_eps, - use_flash_attention=use_flash_attention, ) ) blocks.append( @@ -241,17 +238,36 @@ def __init__( if reversed_attention_levels[i]: blocks.append( - _AttentionBlock( + AttentionBlock( spatial_dims=spatial_dims, num_channels=block_in_ch, norm_num_groups=norm_num_groups, norm_eps=norm_eps, - use_flash_attention=use_flash_attention, ) ) if not is_final_block: - blocks.append(_Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=False)) + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=block_in_ch, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + blocks.append( + Upsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=block_in_ch, + out_channels=block_in_ch, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, + ) + ) blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) blocks.append( @@ -297,7 +313,6 @@ class SPADEAutoencoderKL(nn.Module): norm_eps: epsilon for the normalization. with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. spade_intermediate_channels: number of intermediate channels for SPADE block layer. """ @@ -315,7 +330,6 @@ def __init__( norm_eps: float = 1e-6, with_encoder_nonlocal_attn: bool = True, with_decoder_nonlocal_attn: bool = True, - use_flash_attention: bool = False, spade_intermediate_channels: int = 128, ) -> None: super().__init__() @@ -336,11 +350,6 @@ def __init__( "`channels`." ) - if use_flash_attention is True and not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." - ) - self.encoder = Encoder( spatial_dims=spatial_dims, in_channels=in_channels, @@ -351,7 +360,6 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_encoder_nonlocal_attn, - use_flash_attention=use_flash_attention, ) self.decoder = SPADEDecoder( spatial_dims=spatial_dims, @@ -364,7 +372,6 @@ def __init__( attention_levels=attention_levels, label_nc=label_nc, with_nonlocal_attn=with_decoder_nonlocal_attn, - use_flash_attention=use_flash_attention, spade_intermediate_channels=spade_intermediate_channels, ) self.quant_conv_mu = Convolution( diff --git a/monai/networks/nets/vnet.py b/monai/networks/nets/vnet.py index d89eb8ae03..2815224e08 100644 --- a/monai/networks/nets/vnet.py +++ b/monai/networks/nets/vnet.py @@ -30,6 +30,7 @@ def get_acti_layer(act: tuple[str, dict] | str, nchan: int = 0): class LUConv(nn.Module): + def __init__(self, spatial_dims: int, nchan: int, act: tuple[str, dict] | str, bias: bool = False): super().__init__() @@ -58,6 +59,7 @@ def _make_nconv(spatial_dims: int, nchan: int, depth: int, act: tuple[str, dict] class InputTransition(nn.Module): + def __init__( self, spatial_dims: int, in_channels: int, out_channels: int, act: tuple[str, dict] | str, bias: bool = False ): @@ -91,6 +93,7 @@ def forward(self, x): class DownTransition(nn.Module): + def __init__( self, spatial_dims: int, @@ -127,6 +130,7 @@ def forward(self, x): class UpTransition(nn.Module): + def __init__( self, spatial_dims: int, @@ -165,6 +169,7 @@ def forward(self, x, skipx): class OutputTransition(nn.Module): + def __init__( self, spatial_dims: int, in_channels: int, out_channels: int, act: tuple[str, dict] | str, bias: bool = False ): diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 42e537648a..ecf237a2ff 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -42,6 +42,7 @@ "predict_segmentation", "normalize_transform", "to_norm_affine", + "CastTempType", "normal_init", "icnr_init", "pixelshuffle", @@ -840,7 +841,6 @@ def _onnx_trt_compile( # set up the conversion configuration config = builder.create_builder_config() - config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 31) config.add_optimization_profile(profile) if precision == "fp16": config.set_flag(trt.BuilderFlag.FP16) @@ -850,7 +850,10 @@ def _onnx_trt_compile( # wrap the serialized TensorRT engine back to a TorchScript module. trt_model = torch_tensorrt.ts.embed_engine_in_new_module( - f.getvalue(), torch.device(f"cuda:{device}"), input_names, output_names + f.getvalue(), + device=torch.device(f"cuda:{device}"), + input_binding_names=input_names, + output_binding_names=output_names, ) return trt_model @@ -1164,3 +1167,24 @@ def freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None): warnings.warn(f"The exclude_vars includes {param}, but requires_grad is False, change it to True.") logger.info(f"{len(frozen_keys)} of {len(src_dict)} variables frozen.") + + +class CastTempType(nn.Module): + """ + Cast the input tensor to a temporary type before applying the submodule, and then cast it back to the initial type. + """ + + def __init__(self, initial_type, temporary_type, submodule): + super().__init__() + self.initial_type = initial_type + self.temporary_type = temporary_type + self.submodule = submodule + + def forward(self, x): + dtype = x.dtype + if dtype == self.initial_type: + x = x.to(self.temporary_type) + x = self.submodule(x) + if dtype == self.initial_type: + x = x.to(self.initial_type) + return x diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index 75e108ae71..045135628d 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -43,6 +43,7 @@ class DataLoaderIter: + def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None: if not isinstance(data_loader, DataLoader): raise ValueError( @@ -71,6 +72,7 @@ def __next__(self): class TrainDataLoaderIter(DataLoaderIter): + def __init__( self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable, auto_reset: bool = True ) -> None: diff --git a/monai/optimizers/utils.py b/monai/optimizers/utils.py index 7e566abb46..75a125f076 100644 --- a/monai/optimizers/utils.py +++ b/monai/optimizers/utils.py @@ -70,12 +70,14 @@ def generate_param_groups( lr_values = ensure_tuple_rep(lr_values, len(layer_matches)) def _get_select(f): + def _select(): return f(network).parameters() return _select def _get_filter(f): + def _filter(): # should eventually generate a list of network parameters return (x[1] for x in filter(f, network.named_parameters())) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 2aa8fbf8a1..ab9adb6a99 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -92,6 +92,7 @@ from .croppad.functional import crop_func, crop_or_pad_nd, pad_func, pad_nd from .intensity.array import ( AdjustContrast, + ClipIntensityPercentiles, ComputeHoVerMaps, DetectEnvelope, ForegroundMask, @@ -135,6 +136,9 @@ AdjustContrastd, AdjustContrastD, AdjustContrastDict, + ClipIntensityPercentilesd, + ClipIntensityPercentilesD, + ClipIntensityPercentilesDict, ComputeHoVerMapsd, ComputeHoVerMapsD, ComputeHoVerMapsDict, @@ -336,6 +340,18 @@ VoteEnsembled, VoteEnsembleDict, ) +from .regularization.array import CutMix, CutOut, MixUp +from .regularization.dictionary import ( + CutMixd, + CutMixD, + CutMixDict, + CutOutd, + CutOutD, + CutOutDict, + MixUpd, + MixUpD, + MixUpDict, +) from .signal.array import ( SignalContinuousWavelet, SignalFillEmpty, diff --git a/monai/transforms/adaptors.py b/monai/transforms/adaptors.py index 5729740690..f5f1a4fc18 100644 --- a/monai/transforms/adaptors.py +++ b/monai/transforms/adaptors.py @@ -132,6 +132,7 @@ def __call__(self, img, seg): @_monai_export("monai.transforms") def adaptor(function, outputs, inputs=None): + def must_be_types_or_none(variable_name, variable, types): if variable is not None: if not isinstance(variable, types): @@ -216,6 +217,7 @@ def _inner(ditems): @_monai_export("monai.transforms") def apply_alias(fn, name_map): + def _inner(data): # map names pre_call = dict(data) @@ -236,6 +238,7 @@ def _inner(data): @_monai_export("monai.transforms") def to_kwargs(fn): + def _inner(data): return fn(**data) @@ -243,6 +246,7 @@ def _inner(data): class FunctionSignature: + def __init__(self, function: Callable) -> None: import inspect diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index f9667402c9..f656475a36 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -30,7 +30,7 @@ from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter from monai.transforms.transform import RandomizableTransform, Transform -from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array +from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array, soft_clip from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where from monai.utils.enums import TransformBackends from monai.utils.misc import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple @@ -54,6 +54,7 @@ "NormalizeIntensity", "ThresholdIntensity", "ScaleIntensityRange", + "ClipIntensityPercentiles", "AdjustContrast", "RandAdjustContrast", "ScaleIntensityRangePercentiles", @@ -91,24 +92,33 @@ class RandGaussianNoise(RandomizableTransform): mean: Mean or “centre” of the distribution. std: Standard deviation (spread) of distribution. dtype: output data type, if None, same as input image. defaults to float32. + sample_std: If True, sample the spread of the Gaussian distribution uniformly from 0 to std. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, prob: float = 0.1, mean: float = 0.0, std: float = 0.1, dtype: DtypeLike = np.float32) -> None: + def __init__( + self, + prob: float = 0.1, + mean: float = 0.0, + std: float = 0.1, + dtype: DtypeLike = np.float32, + sample_std: bool = True, + ) -> None: RandomizableTransform.__init__(self, prob) self.mean = mean self.std = std self.dtype = dtype self.noise: np.ndarray | None = None + self.sample_std = sample_std def randomize(self, img: NdarrayOrTensor, mean: float | None = None) -> None: super().randomize(None) if not self._do_transform: return None - rand_std = self.R.uniform(0, self.std) - noise = self.R.normal(self.mean if mean is None else mean, rand_std, size=img.shape) + std = self.R.uniform(0, self.std) if self.sample_std else self.std + noise = self.R.normal(self.mean if mean is None else mean, std, size=img.shape) # noise is float64 array, convert to the output dtype to save memory self.noise, *_ = convert_data_type(noise, dtype=self.dtype) @@ -998,6 +1008,151 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: return ret +class ClipIntensityPercentiles(Transform): + """ + Apply clip based on the intensity distribution of input image. + If `sharpness_factor` is provided, the intensity values will be soft clipped according to + f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv)) + From https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291 + + Soft clipping preserves the order of the values and maintains the gradient everywhere. + For example: + + .. code-block:: python + :emphasize-lines: 11, 22 + + image = torch.Tensor( + [[[1, 2, 3, 4, 5], + [1, 2, 3, 4, 5], + [1, 2, 3, 4, 5], + [1, 2, 3, 4, 5], + [1, 2, 3, 4, 5], + [1, 2, 3, 4, 5]]]) + + # Hard clipping from lower and upper image intensity percentiles + hard_clipper = ClipIntensityPercentiles(30, 70) + print(hard_clipper(image)) + metatensor([[[2., 2., 3., 4., 4.], + [2., 2., 3., 4., 4.], + [2., 2., 3., 4., 4.], + [2., 2., 3., 4., 4.], + [2., 2., 3., 4., 4.], + [2., 2., 3., 4., 4.]]]) + + + # Soft clipping from lower and upper image intensity percentiles + soft_clipper = ClipIntensityPercentiles(30, 70, 10.) + print(soft_clipper(image)) + metatensor([[[2.0000, 2.0693, 3.0000, 3.9307, 4.0000], + [2.0000, 2.0693, 3.0000, 3.9307, 4.0000], + [2.0000, 2.0693, 3.0000, 3.9307, 4.0000], + [2.0000, 2.0693, 3.0000, 3.9307, 4.0000], + [2.0000, 2.0693, 3.0000, 3.9307, 4.0000], + [2.0000, 2.0693, 3.0000, 3.9307, 4.0000]]]) + + See Also: + + - :py:class:`monai.transforms.ScaleIntensityRangePercentiles` + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + lower: float | None, + upper: float | None, + sharpness_factor: float | None = None, + channel_wise: bool = False, + return_clipping_values: bool = False, + dtype: DtypeLike = np.float32, + ) -> None: + """ + Args: + lower: lower intensity percentile. In the case of hard clipping, None will have the same effect as 0 by + not clipping the lowest input values. However, in the case of soft clipping, None and zero will have + two different effects: None will not apply clipping to low values, whereas zero will still transform + the lower values according to the soft clipping transformation. Please check for more details: + https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291. + upper: upper intensity percentile. The same as for lower, but this time with the highest values. If we + are looking to perform soft clipping, if None then there will be no effect on this side whereas if set + to 100, the values will be passed via the corresponding clipping equation. + sharpness_factor: if not None, the intensity values will be soft clipped according to + f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv)). + defaults to None. + channel_wise: if True, compute intensity percentile and normalize every channel separately. + default to False. + return_clipping_values: whether to return the calculated percentiles in tensor meta information. + If soft clipping and requested percentile is None, return None as the corresponding clipping + values in meta information. Clipping values are stored in a list with each element corresponding + to a channel if channel_wise is set to True. defaults to False. + dtype: output data type, if None, same as input image. defaults to float32. + """ + if lower is None and upper is None: + raise ValueError("lower or upper percentiles must be provided") + if lower is not None and (lower < 0.0 or lower > 100.0): + raise ValueError("Percentiles must be in the range [0, 100]") + if upper is not None and (upper < 0.0 or upper > 100.0): + raise ValueError("Percentiles must be in the range [0, 100]") + if upper is not None and lower is not None and upper < lower: + raise ValueError("upper must be greater than or equal to lower") + if sharpness_factor is not None and sharpness_factor <= 0: + raise ValueError("sharpness_factor must be greater than 0") + + self.lower = lower + self.upper = upper + self.sharpness_factor = sharpness_factor + self.channel_wise = channel_wise + if return_clipping_values: + self.clipping_values: list[tuple[float | None, float | None]] = [] + self.return_clipping_values = return_clipping_values + self.dtype = dtype + + def _clip(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + if self.sharpness_factor is not None: + lower_percentile = percentile(img, self.lower) if self.lower is not None else None + upper_percentile = percentile(img, self.upper) if self.upper is not None else None + img = soft_clip(img, self.sharpness_factor, lower_percentile, upper_percentile, self.dtype) + else: + lower_percentile = percentile(img, self.lower) if self.lower is not None else percentile(img, 0) + upper_percentile = percentile(img, self.upper) if self.upper is not None else percentile(img, 100) + img = clip(img, lower_percentile, upper_percentile) + + if self.return_clipping_values: + self.clipping_values.append( + ( + ( + lower_percentile + if lower_percentile is None + else lower_percentile.item() if hasattr(lower_percentile, "item") else lower_percentile + ), + ( + upper_percentile + if upper_percentile is None + else upper_percentile.item() if hasattr(upper_percentile, "item") else upper_percentile + ), + ) + ) + img = convert_to_tensor(img, track_meta=False) + return img + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + """ + Apply the transform to `img`. + """ + img = convert_to_tensor(img, track_meta=get_track_meta()) + img_t = convert_to_tensor(img, track_meta=False) + if self.channel_wise: + img_t = torch.stack([self._clip(img=d) for d in img_t]) # type: ignore + else: + img_t = self._clip(img=img_t) + + img = convert_to_dst_type(img_t, dst=img)[0] + if self.return_clipping_values: + img.meta["clipping_values"] = self.clipping_values # type: ignore + + return img + + class AdjustContrast(Transform): """ Changes image intensity with gamma transform. Each pixel/voxel intensity is updated as:: @@ -1831,15 +1986,19 @@ class RandGibbsNoise(RandomizableTransform): Args: prob (float): probability of applying the transform. - alpha (Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes + alpha (float, Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes values in the interval [0,1] with alpha = 0 acting as the identity mapping. If a length-2 list is given as [a,b] then the value of alpha will be sampled uniformly from the interval [a,b]. 0 <= a <= b <= 1. + If a float is given, then the value of alpha will be sampled uniformly from the interval [0, alpha]. """ backend = GibbsNoise.backend - def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0)) -> None: + def __init__(self, prob: float = 0.1, alpha: float | Sequence[float] = (0.0, 1.0)) -> None: + if isinstance(alpha, float): + alpha = (0, alpha) + alpha = ensure_tuple(alpha) if len(alpha) != 2: raise ValueError("alpha length must be 2.") if alpha[1] > 1 or alpha[0] < 0: diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 058ef87b95..5dbac485fe 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -26,6 +26,7 @@ from monai.data.meta_obj import get_track_meta from monai.transforms.intensity.array import ( AdjustContrast, + ClipIntensityPercentiles, ComputeHoVerMaps, ForegroundMask, GaussianSharpen, @@ -77,6 +78,7 @@ "NormalizeIntensityd", "ThresholdIntensityd", "ScaleIntensityRanged", + "ClipIntensityPercentilesd", "AdjustContrastd", "RandAdjustContrastd", "ScaleIntensityRangePercentilesd", @@ -122,6 +124,8 @@ "ThresholdIntensityDict", "ScaleIntensityRangeD", "ScaleIntensityRangeDict", + "ClipIntensityPercentilesD", + "ClipIntensityPercentilesDict", "AdjustContrastD", "AdjustContrastDict", "RandAdjustContrastD", @@ -172,7 +176,7 @@ class RandGaussianNoised(RandomizableTransform, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandGaussianNoise`. - Add Gaussian noise to image. This transform assumes all the expected fields have same shape, if want to add + Add Gaussian noise to image. This transform assumes all the expected fields have same shape, if you want to add different noise for every field, please use this transform separately. Args: @@ -183,6 +187,7 @@ class RandGaussianNoised(RandomizableTransform, MapTransform): std: Standard deviation (spread) of distribution. dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. + sample_std: If True, sample the spread of the Gaussian distribution uniformly from 0 to std. """ backend = RandGaussianNoise.backend @@ -195,10 +200,11 @@ def __init__( std: float = 0.1, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, + sample_std: bool = True, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.rand_gaussian_noise = RandGaussianNoise(mean=mean, std=std, prob=1.0, dtype=dtype) + self.rand_gaussian_noise = RandGaussianNoise(mean=mean, std=std, prob=1.0, dtype=dtype, sample_std=sample_std) def set_random_state( self, seed: int | None = None, state: np.random.RandomState | None = None @@ -884,6 +890,36 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class ClipIntensityPercentilesd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ClipIntensityPercentiles`. + Clip the intensity values of input image to a specific range based on the intensity distribution of the input. + If `sharpness_factor` is provided, the intensity values will be soft clipped according to + f(x) = x + (1/sharpness_factor) * softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv)) + """ + + def __init__( + self, + keys: KeysCollection, + lower: float | None, + upper: float | None, + sharpness_factor: float | None = None, + channel_wise: bool = False, + dtype: DtypeLike = np.float32, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.scaler = ClipIntensityPercentiles( + lower=lower, upper=upper, sharpness_factor=sharpness_factor, channel_wise=channel_wise, dtype=dtype + ) + + def __call__(self, data: dict) -> dict: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.scaler(d[key]) + return d + + class AdjustContrastd(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.AdjustContrast`. @@ -1421,10 +1457,11 @@ class RandGibbsNoised(RandomizableTransform, MapTransform): keys: 'image', 'label', or ['image', 'label'] depending on which data you need to transform. prob (float): probability of applying the transform. - alpha (float, List[float]): Parametrizes the intensity of the Gibbs noise filter applied. Takes + alpha (float, Sequence[float]): Parametrizes the intensity of the Gibbs noise filter applied. Takes values in the interval [0,1] with alpha = 0 acting as the identity mapping. If a length-2 list is given as [a,b] then the value of alpha will be sampled uniformly from the interval [a,b]. + If a float is given, then the value of alpha will be sampled uniformly from the interval [0, alpha]. allow_missing_keys: do not raise exception if key is missing. """ @@ -1434,7 +1471,7 @@ def __init__( self, keys: KeysCollection, prob: float = 0.1, - alpha: Sequence[float] = (0.0, 1.0), + alpha: float | Sequence[float] = (0.0, 1.0), allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) @@ -1926,6 +1963,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N NormalizeIntensityD = NormalizeIntensityDict = NormalizeIntensityd ThresholdIntensityD = ThresholdIntensityDict = ThresholdIntensityd ScaleIntensityRangeD = ScaleIntensityRangeDict = ScaleIntensityRanged +ClipIntensityPercentilesD = ClipIntensityPercentilesDict = ClipIntensityPercentilesd AdjustContrastD = AdjustContrastDict = AdjustContrastd RandAdjustContrastD = RandAdjustContrastDict = RandAdjustContrastd ScaleIntensityRangePercentilesD = ScaleIntensityRangePercentilesDict = ScaleIntensityRangePercentilesd diff --git a/monai/transforms/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py index 73149f1be5..1a7d16fb8c 100644 --- a/monai/transforms/inverse_batch_transform.py +++ b/monai/transforms/inverse_batch_transform.py @@ -30,6 +30,7 @@ class _BatchInverseDataset(Dataset): + def __init__(self, data: Sequence[Any], transform: InvertibleTransform, pad_collation_used: bool) -> None: self.data = data self.invertible_transform = transform diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 3d5d30be92..da9b23ce57 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -631,6 +631,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: class Ensemble: + @staticmethod def get_stacked_torch(img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> torch.Tensor: """Get either a sequence or single instance of np.ndarray/torch.Tensor. Return single torch.Tensor.""" diff --git a/monai/transforms/regularization/__init__.py b/monai/transforms/regularization/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/transforms/regularization/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py new file mode 100644 index 0000000000..0b495c8623 --- /dev/null +++ b/monai/transforms/regularization/array.py @@ -0,0 +1,174 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +from math import ceil, sqrt + +import torch + +from ..transform import RandomizableTransform + +__all__ = ["MixUp", "CutMix", "CutOut", "Mixer"] + + +class Mixer(RandomizableTransform): + + def __init__(self, batch_size: int, alpha: float = 1.0) -> None: + """ + Mixer is a base class providing the basic logic for the mixup-class of + augmentations. In all cases, we need to sample the mixing weights for each + sample (lambda in the notation used in the papers). Also, pairs of samples + being mixed are picked by randomly shuffling the batch samples. + + Args: + batch_size (int): number of samples per batch. That is, samples are expected tp + be of size batchsize x channels [x depth] x height x width. + alpha (float, optional): mixing weights are sampled from the Beta(alpha, alpha) + distribution. Defaults to 1.0, the uniform distribution. + """ + super().__init__() + if alpha <= 0: + raise ValueError(f"Expected positive number, but got {alpha = }") + self.alpha = alpha + self.batch_size = batch_size + + @abstractmethod + def apply(self, data: torch.Tensor): + raise NotImplementedError() + + def randomize(self, data=None) -> None: + """ + Sometimes you need may to apply the same transform to different tensors. + The idea is to get a sample and then apply it with apply() as often + as needed. You need to call this method everytime you apply the transform to a new + batch. + """ + self._params = ( + torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32), + self.R.permutation(self.batch_size), + ) + + +class MixUp(Mixer): + """MixUp as described in: + Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz. + mixup: Beyond Empirical Risk Minimization, ICLR 2018 + + Class derived from :py:class:`monai.transforms.Mixer`. See corresponding + documentation for details on the constructor parameters. + """ + + def apply(self, data: torch.Tensor): + weight, perm = self._params + nsamples, *dims = data.shape + if len(weight) != nsamples: + raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}") + + if len(dims) not in [3, 4]: + raise ValueError("Unexpected number of dimensions") + + mixweight = weight[(Ellipsis,) + (None,) * len(dims)] + return mixweight * data + (1 - mixweight) * data[perm, ...] + + def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None): + self.randomize() + if labels is None: + return self.apply(data) + return self.apply(data), self.apply(labels) + + +class CutMix(Mixer): + """CutMix augmentation as described in: + Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo. + CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features, + ICCV 2019 + + Class derived from :py:class:`monai.transforms.Mixer`. See corresponding + documentation for details on the constructor parameters. Here, alpha not only determines + the mixing weight but also the size of the random rectangles used during for mixing. + Please refer to the paper for details. + + The most common use case is something close to: + + .. code-block:: python + + cm = CutMix(batch_size=8, alpha=0.5) + for batch in loader: + images, labels = batch + augimg, auglabels = cm(images, labels) + output = model(augimg) + loss = loss_function(output, auglabels) + ... + + """ + + def apply(self, data: torch.Tensor): + weights, perm = self._params + nsamples, _, *dims = data.shape + if len(weights) != nsamples: + raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") + + mask = torch.ones_like(data) + for s, weight in enumerate(weights): + coords = [torch.randint(0, d, size=(1,)) for d in dims] + lengths = [d * sqrt(1 - weight) for d in dims] + idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)] + mask[s][idx] = 0 + + return mask * data + (1 - mask) * data[perm, ...] + + def apply_on_labels(self, labels: torch.Tensor): + weights, perm = self._params + nsamples, *dims = labels.shape + if len(weights) != nsamples: + raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") + + mixweight = weights[(Ellipsis,) + (None,) * len(dims)] + return mixweight * labels + (1 - mixweight) * labels[perm, ...] + + def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None): + self.randomize() + augmented = self.apply(data) + return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented + + +class CutOut(Mixer): + """Cutout as described in the paper: + Terrance DeVries, Graham W. Taylor. + Improved Regularization of Convolutional Neural Networks with Cutout, + arXiv:1708.04552 + + Class derived from :py:class:`monai.transforms.Mixer`. See corresponding + documentation for details on the constructor parameters. Here, alpha not only determines + the mixing weight but also the size of the random rectangles being cut put. + Please refer to the paper for details. + """ + + def apply(self, data: torch.Tensor): + weights, _ = self._params + nsamples, _, *dims = data.shape + if len(weights) != nsamples: + raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") + + mask = torch.ones_like(data) + for s, weight in enumerate(weights): + coords = [torch.randint(0, d, size=(1,)) for d in dims] + lengths = [d * sqrt(1 - weight) for d in dims] + idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)] + mask[s][idx] = 0 + + return mask * data + + def __call__(self, data: torch.Tensor): + self.randomize() + return self.apply(data) diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py new file mode 100644 index 0000000000..373913da99 --- /dev/null +++ b/monai/transforms/regularization/dictionary.py @@ -0,0 +1,97 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from monai.config import KeysCollection +from monai.utils.misc import ensure_tuple + +from ..transform import MapTransform +from .array import CutMix, CutOut, MixUp + +__all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"] + + +class MixUpd(MapTransform): + """ + Dictionary-based version :py:class:`monai.transforms.MixUp`. + + Notice that the mixup transformation will be the same for all entries + for consistency, i.e. images and labels must be applied the same augmenation. + """ + + def __init__( + self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False + ) -> None: + super().__init__(keys, allow_missing_keys) + self.mixup = MixUp(batch_size, alpha) + + def __call__(self, data): + self.mixup.randomize() + result = dict(data) + for k in self.keys: + result[k] = self.mixup.apply(data[k]) + return result + + +class CutMixd(MapTransform): + """ + Dictionary-based version :py:class:`monai.transforms.CutMix`. + + Notice that the mixture weights will be the same for all entries + for consistency, i.e. images and labels must be aggregated with the same weights, + but the random crops are not. + """ + + def __init__( + self, + keys: KeysCollection, + batch_size: int, + label_keys: KeysCollection | None = None, + alpha: float = 1.0, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.mixer = CutMix(batch_size, alpha) + self.label_keys = ensure_tuple(label_keys) if label_keys is not None else [] + + def __call__(self, data): + self.mixer.randomize() + result = dict(data) + for k in self.keys: + result[k] = self.mixer.apply(data[k]) + for k in self.label_keys: + result[k] = self.mixer.apply_on_labels(data[k]) + return result + + +class CutOutd(MapTransform): + """ + Dictionary-based version :py:class:`monai.transforms.CutOut`. + + Notice that the cutout is different for every entry in the dictionary. + """ + + def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bool = False) -> None: + super().__init__(keys, allow_missing_keys) + self.cutout = CutOut(batch_size) + + def __call__(self, data): + result = dict(data) + self.cutout.randomize() + for k in self.keys: + result[k] = self.cutout(data[k]) + return result + + +MixUpD = MixUpDict = MixUpd +CutMixD = CutMixDict = CutMixd +CutOutD = CutOutDict = CutOutd diff --git a/monai/transforms/smooth_field/array.py b/monai/transforms/smooth_field/array.py index c9df5f1dbb..9d19263f8b 100644 --- a/monai/transforms/smooth_field/array.py +++ b/monai/transforms/smooth_field/array.py @@ -96,7 +96,7 @@ def __init__( self.set_spatial_size(spatial_size) def randomize(self, data: Any | None = None) -> None: - self.field[self.rand_slices] = torch.from_numpy(self.R.uniform(self.low, self.high, self.crand_size)) + self.field[self.rand_slices] = torch.from_numpy(self.R.uniform(self.low, self.high, self.crand_size)) # type: ignore[index] def set_spatial_size(self, spatial_size: Sequence[int] | None) -> None: """ diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 8ad86b72dd..094afdd3c4 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -3441,7 +3441,7 @@ def filter_count(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tupl idx = self.R.permutation(image_np.shape[0]) idx = idx[: self.num_patches] idx_np = convert_data_type(idx, np.ndarray)[0] - image_np = image_np[idx] # type: ignore + image_np = image_np[idx] locations = locations[idx_np] return image_np, locations elif self.sort_fn not in (None, GridPatchSort.MIN, GridPatchSort.MAX): diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 1cd9ff6323..7e3a7b0454 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -670,6 +670,7 @@ def __init__(self, keys: KeysCollection, sep: str = ".", use_re: Sequence[bool] self.use_re = ensure_tuple_rep(use_re, len(self.keys)) def __call__(self, data): + def _delete_item(keys, d, use_re: bool = False): key = keys[0] if len(keys) > 1: diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e282ecff24..560dbac346 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -38,6 +38,7 @@ nonzero, ravel, searchsorted, + softplus, unique, unravel_index, where, @@ -131,9 +132,45 @@ "resolves_modes", "has_status_keys", "distance_transform_edt", + "soft_clip", ] +def soft_clip( + arr: NdarrayOrTensor, + sharpness_factor: float = 1.0, + minv: NdarrayOrTensor | float | int | None = None, + maxv: NdarrayOrTensor | float | int | None = None, + dtype: DtypeLike | torch.dtype = np.float32, +) -> NdarrayOrTensor: + """ + Apply soft clip to the input array or tensor. + The intensity values will be soft clipped according to + f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv)) + From https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291 + + To perform one-sided clipping, set either minv or maxv to None. + Args: + arr: input array to clip. + sharpness_factor: the sharpness of the soft clip function, default to 1. + minv: minimum value of target clipped array. + maxv: maximum value of target clipped array. + dtype: if not None, convert input array to dtype before computation. + + """ + + if dtype is not None: + arr, *_ = convert_data_type(arr, dtype=dtype) + + v = arr + if minv is not None: + v = v + softplus(-sharpness_factor * (arr - minv)) / sharpness_factor + if maxv is not None: + v = v - softplus(sharpness_factor * (arr - maxv)) / sharpness_factor + + return v + + def rand_choice(prob: float = 0.5) -> bool: """ Returns True if a randomly chosen number is less than or equal to `prob`, by default this is a 50/50 chance. @@ -625,9 +662,12 @@ def generate_label_classes_crop_centers( for i, array in enumerate(indices): if len(array) == 0: - ratios_[i] = 0 - if warn: - warnings.warn(f"no available indices of class {i} to crop, set the crop ratio of this class to zero.") + if ratios_[i] != 0: + ratios_[i] = 0 + if warn: + warnings.warn( + f"no available indices of class {i} to crop, setting the crop ratio of this class to zero." + ) centers = [] classes = rand_state.choice(len(ratios_), size=num_samples, p=np.asarray(ratios_) / np.sum(ratios_)) @@ -2150,7 +2190,7 @@ def distance_transform_edt( if return_distances: dtype = torch.float64 if float64_distances else torch.float32 if distances is None: - distances = torch.zeros_like(img, dtype=dtype) # type: ignore + distances = torch.zeros_like(img, memory_format=torch.contiguous_format, dtype=dtype) # type: ignore else: if not isinstance(distances, torch.Tensor) and distances.device != img.device: raise TypeError("distances must be a torch.Tensor on the same device as img") diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 0774d50314..020d99af16 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -52,9 +52,24 @@ "median", "mean", "std", + "softplus", ] +def softplus(x: NdarrayOrTensor) -> NdarrayOrTensor: + """stable softplus through `np.logaddexp` with equivalent implementation for torch. + + Args: + x: array/tensor. + + Returns: + Softplus of the input. + """ + if isinstance(x, np.ndarray): + return np.logaddexp(np.zeros_like(x), x) + return torch.logaddexp(torch.zeros_like(x), x) + + def allclose(a: NdarrayTensor, b: NdarrayOrTensor, rtol=1e-5, atol=1e-8, equal_nan=False) -> bool: """`np.allclose` with equivalent implementation for torch.""" b, *_ = convert_to_dst_type(b, a, wrap_sequence=True) diff --git a/monai/utils/enums.py b/monai/utils/enums.py index a0847dd76c..b786e92151 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -411,7 +411,7 @@ class CompInitMode(StrEnum): """ DEFAULT = "default" - PARTIAL = "partial" + CALLABLE = "callable" DEBUG = "debug" diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 81f582daef..886103a0ab 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -527,7 +527,7 @@ def doc_images() -> str | None: @staticmethod def algo_hash() -> str | None: - return os.environ.get("MONAI_ALGO_HASH", "249bf4b") + return os.environ.get("MONAI_ALGO_HASH", "4403f94") @staticmethod def trace_transform() -> str | None: @@ -742,6 +742,7 @@ def check_key_duplicates(ordered_pairs: Sequence[tuple[Any, Any]]) -> dict[Any, class CheckKeyDuplicatesYamlLoader(SafeLoader): + def construct_mapping(self, node, deep=False): mapping = set() for key_node, _ in node.value: diff --git a/monai/utils/module.py b/monai/utils/module.py index f46ba7c1b3..6f301d8067 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -209,7 +209,7 @@ def load_submodules( if (is_pkg or load_all) and name not in sys.modules and match(exclude_pattern, name) is None: try: mod = import_module(name) - importer.find_module(name).load_module(name) # type: ignore + importer.find_spec(name).loader.load_module(name) # type: ignore submodules.append(mod) except OptionalImportError: pass # could not import the optional deps., they are ignored @@ -231,11 +231,14 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any: Args: __path: if a string is provided, it's interpreted as the full path of the target class or function component. - If a callable is provided, ``__path(**kwargs)`` or ``functools.partial(__path, **kwargs)`` will be returned. + If a callable is provided, ``__path(**kwargs)`` will be invoked and returned for ``__mode="default"``. + For ``__mode="callable"``, the callable will be returned as ``__path`` or, if ``kwargs`` are provided, + as ``functools.partial(__path, **kwargs)`` for future invoking. + __mode: the operating mode for invoking the (callable) ``component`` represented by ``__path``: - ``"default"``: returns ``component(**kwargs)`` - - ``"partial"``: returns ``functools.partial(component, **kwargs)`` + - ``"callable"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)`` - ``"debug"``: returns ``pdb.runcall(component, **kwargs)`` kwargs: keyword arguments to the callable represented by ``__path``. @@ -259,8 +262,8 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any: return component if m == CompInitMode.DEFAULT: return component(**kwargs) - if m == CompInitMode.PARTIAL: - return partial(component, **kwargs) + if m == CompInitMode.CALLABLE: + return partial(component, **kwargs) if kwargs else component if m == CompInitMode.DEBUG: warnings.warn( f"\n\npdb: instantiating component={component}, mode={m}\n" @@ -269,7 +272,7 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any: return pdb.runcall(component, **kwargs) except Exception as e: raise RuntimeError( - f"Failed to instantiate component '{__path}' with kwargs: {kwargs}" + f"Failed to instantiate component '{__path}' with keywords: {','.join(kwargs.keys())}" f"\n set '_mode_={CompInitMode.DEBUG}' to enter the debugging mode." ) from e @@ -418,6 +421,7 @@ def optional_import( msg += f" ({exception_str})" class _LazyRaise: + def __init__(self, *_args, **_kwargs): _default_msg = ( f"{msg}." @@ -453,6 +457,7 @@ def __iter__(self): return _LazyRaise(), False class _LazyCls(_LazyRaise): + def __init__(self, *_args, **kwargs): super().__init__() if not as_type.startswith("decorator"): diff --git a/monai/utils/profiling.py b/monai/utils/profiling.py index da5c0ac05c..5c880bbe1f 100644 --- a/monai/utils/profiling.py +++ b/monai/utils/profiling.py @@ -336,6 +336,7 @@ def profile_iter(self, name, iterable): """Wrapper around anything iterable to profile how long it takes to generate items.""" class _Iterable: + def __iter__(_self): # noqa: B902, N805 pylint: disable=E0213 do_iter = True orig_iter = iter(iterable) diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index 81d0bb32c4..6d1e8dfd03 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -96,12 +96,14 @@ def __init__( warnings.warn(f"Not all target_layers exist in the network module: targets: {self.target_layers}.") def backward_hook(self, name): + def _hook(_module, _grad_input, grad_output): self.gradients[name] = grad_output[0] return _hook def forward_hook(self, name): + def _hook(_module, _input, output): self.activations[name] = output diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py index e2e938d86b..c54c9cd4ca 100644 --- a/monai/visualize/gradient_based.py +++ b/monai/visualize/gradient_based.py @@ -26,6 +26,7 @@ class _AutoGradReLU(torch.autograd.Function): + @staticmethod def forward(ctx, x): pos_mask = (x > 0).type_as(x) diff --git a/pyproject.toml b/pyproject.toml index cd8a510b04..50d0b09672 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,8 +38,8 @@ exclude = "monai/bundle/__main__.py" [tool.ruff] line-length = 133 -ignore-init-module-imports = true -ignore = ["F401", "E741"] +lint.ignore-init-module-imports = true +lint.ignore = ["F401", "E741"] [tool.pytype] # Space-separated list of files or directories to exclude. diff --git a/requirements-dev.txt b/requirements-dev.txt index f8bc9d5a3e..b207b56b19 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,17 +1,17 @@ # Full requirements for developments -r requirements-min.txt pytorch-ignite==0.4.11 -gdown>=4.4.0, <=4.6.3 +gdown>=4.7.3 scipy>=1.7.1 itk>=5.2 nibabel pillow!=8.3.0 # https://github.com/python-pillow/Pillow/issues/5571 -tensorboard>=2.6 # https://github.com/Project-MONAI/MONAI/issues/5776 +tensorboard>=2.12.0 # https://github.com/Project-MONAI/MONAI/issues/7434 scikit-image>=0.19.0 tqdm>=4.47.0 lmdb flake8>=3.8.1 -flake8-bugbear +flake8-bugbear<=24.2.6 # https://github.com/Project-MONAI/MONAI/issues/7690 flake8-comprehensions mccabe pep8-naming @@ -26,7 +26,7 @@ mypy>=1.5.0 ninja torchvision psutil -cucim>=23.2.0; platform_system == "Linux" +cucim-cu12; platform_system == "Linux" and python_version >= "3.9" and python_version <= "3.10" openslide-python imagecodecs; platform_system == "Linux" or platform_system == "Darwin" tifffile; platform_system == "Linux" or platform_system == "Darwin" @@ -34,7 +34,7 @@ pandas requests einops transformers>=4.36.0 -mlflow>=1.28.0 +mlflow>=1.28.0, <=2.11.3 clearml>=1.10.0rc0 matplotlib!=3.5.0 tensorboardX @@ -46,7 +46,7 @@ pynrrd pre-commit pydicom h5py -nni; platform_system == "Linux" +nni; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine optuna git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded onnx>=1.13.0 @@ -57,3 +57,4 @@ zarr lpips==0.1.4 nvidia-ml-py huggingface_hub +opencv-python-headless diff --git a/runtests.sh b/runtests.sh index 0c60bc0f58..0b3e20ce49 100755 --- a/runtests.sh +++ b/runtests.sh @@ -738,12 +738,14 @@ fi # network training/inference/eval integration tests if [ $doNetTests = true ] then + set +e # disable exit on failure so that diagnostics can be given on failure echo "${separator}${blue}integration${noColor}" for i in tests/*integration_*.py do echo "$i" ${cmdPrefix}${cmd} "$i" done + set -e # enable exit on failure fi # run model zoo tests diff --git a/setup.cfg b/setup.cfg index 4180ced917..c8ae1630f7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,21 +52,21 @@ all = scipy>=1.7.1 pillow tensorboard - gdown==4.6.3 + gdown>=4.7.3 pytorch-ignite==0.4.11 torchvision itk>=5.2 tqdm>=4.47.0 lmdb psutil - cucim>=23.2.0 + cucim-cu12; python_version >= '3.9' and python_version <= '3.10' openslide-python tifffile imagecodecs pandas einops transformers<4.22; python_version <= '3.10' - mlflow>=1.28.0 + mlflow>=1.28.0, <=2.11.3 clearml>=1.10.0rc0 matplotlib tensorboardX @@ -97,7 +97,7 @@ pillow = tensorboard = tensorboard gdown = - gdown==4.6.3 + gdown>=4.7.3 ignite = pytorch-ignite==0.4.11 torchvision = @@ -111,7 +111,7 @@ lmdb = psutil = psutil cucim = - cucim>=23.2.0 + cucim-cu12 openslide = openslide-python tifffile = @@ -125,7 +125,7 @@ einops = transformers = transformers<4.22; python_version <= '3.10' mlflow = - mlflow + mlflow>=1.28.0, <=2.11.3 matplotlib = matplotlib clearml = diff --git a/tests/croppers.py b/tests/croppers.py index 8c9b43bf0a..cfececfa9f 100644 --- a/tests/croppers.py +++ b/tests/croppers.py @@ -24,6 +24,7 @@ class CropTest(unittest.TestCase): + @staticmethod def get_arr(shape): return np.random.randint(100, size=shape).astype(float) diff --git a/tests/hvd_evenly_divisible_all_gather.py b/tests/hvd_evenly_divisible_all_gather.py index c7baac2bc9..78c6ca06bc 100644 --- a/tests/hvd_evenly_divisible_all_gather.py +++ b/tests/hvd_evenly_divisible_all_gather.py @@ -21,6 +21,7 @@ class HvdEvenlyDivisibleAllGather: + def test_data(self): # initialize Horovod hvd.init() diff --git a/tests/ngc_bundle_download.py b/tests/ngc_bundle_download.py index ba35f2b80c..01dc044870 100644 --- a/tests/ngc_bundle_download.py +++ b/tests/ngc_bundle_download.py @@ -70,6 +70,7 @@ @skip_if_windows class TestNgcBundleDownload(unittest.TestCase): + @parameterized.expand([TEST_CASE_NGC_1, TEST_CASE_NGC_2]) @skip_if_quick def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download_name, file_path, hash_val): @@ -101,6 +102,7 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download @unittest.skip("deprecating mmar tests") class TestAllDownloadingMMAR(unittest.TestCase): + def setUp(self): print_debug_info() self.test_dir = "./" diff --git a/tests/nonconfig_workflow.py b/tests/nonconfig_workflow.py index 7b5328bf72..b2c44c12c6 100644 --- a/tests/nonconfig_workflow.py +++ b/tests/nonconfig_workflow.py @@ -36,8 +36,8 @@ class NonConfigWorkflow(BundleWorkflow): """ - def __init__(self, filename, output_dir): - super().__init__(workflow_type="inference") + def __init__(self, filename, output_dir, meta_file=None, logging_file=None): + super().__init__(workflow_type="inference", meta_file=meta_file, logging_file=logging_file) self.filename = filename self.output_dir = output_dir self._bundle_root = "will override" diff --git a/tests/padders.py b/tests/padders.py index ae1153bdfd..a7dce263bb 100644 --- a/tests/padders.py +++ b/tests/padders.py @@ -51,6 +51,7 @@ class PadTest(unittest.TestCase): + @staticmethod def get_arr(shape): return np.random.randint(100, size=shape).astype(float) diff --git a/tests/profile_subclass/min_classes.py b/tests/profile_subclass/min_classes.py index 7104ffcd59..3e7c52476f 100644 --- a/tests/profile_subclass/min_classes.py +++ b/tests/profile_subclass/min_classes.py @@ -25,5 +25,6 @@ class SubTensor(torch.Tensor): class SubWithTorchFunc(torch.Tensor): + def __torch_function__(self, func, types, args=(), kwargs=None): return super().__torch_function__(func, types, args, {} if kwargs is None else kwargs) diff --git a/tests/test_acn_block.py b/tests/test_acn_block.py index 2f3783cbb8..1cbf3ea168 100644 --- a/tests/test_acn_block.py +++ b/tests/test_acn_block.py @@ -29,6 +29,7 @@ class TestACNBlock(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_acn_block(self, input_param, input_shape, expected_shape): net = ActiConvNormBlock(**input_param) diff --git a/tests/test_activations.py b/tests/test_activations.py index 0e83c73304..ad18e2bbec 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -94,6 +94,7 @@ class TestActivations(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_value_shape(self, input_param, img, out, expected_shape): result = Activations(**input_param)(img) diff --git a/tests/test_activationsd.py b/tests/test_activationsd.py index 22a275997c..74968c0bb4 100644 --- a/tests/test_activationsd.py +++ b/tests/test_activationsd.py @@ -50,6 +50,7 @@ class TestActivationsd(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_value_shape(self, input_param, test_input, output, expected_shape): result = Activationsd(**input_param)(test_input) diff --git a/tests/test_adaptors.py b/tests/test_adaptors.py index 257c4346ad..2495fdc72e 100644 --- a/tests/test_adaptors.py +++ b/tests/test_adaptors.py @@ -18,13 +18,16 @@ class TestAdaptors(unittest.TestCase): + def test_function_signature(self): + def foo(image, label=None, *a, **kw): pass _ = FunctionSignature(foo) def test_single_in_single_out(self): + def foo(image): return image * 2 @@ -55,6 +58,7 @@ def foo(image): self.assertEqual(dres["img"], 4) def test_multi_in_single_out(self): + def foo(image, label): return image * label @@ -86,6 +90,7 @@ def foo(image, label): self.assertEqual(dres["lbl"], 3) def test_default_arg_single_out(self): + def foo(a, b=2): return a * b @@ -98,6 +103,7 @@ def foo(a, b=2): dres = adaptor(foo, "c")(d) def test_multi_out(self): + def foo(a, b): return a * b, a / b @@ -107,6 +113,7 @@ def foo(a, b): self.assertEqual(dres["d"], 3 / 4) def test_dict_out(self): + def foo(a): return {"a": a * 2} @@ -120,7 +127,9 @@ def foo(a): class TestApplyAlias(unittest.TestCase): + def test_apply_alias(self): + def foo(d): d["x"] *= 2 return d @@ -131,7 +140,9 @@ def foo(d): class TestToKwargs(unittest.TestCase): + def test_to_kwargs(self): + def foo(**kwargs): results = {k: v * 2 for k, v in kwargs.items()} return results diff --git a/tests/test_add_coordinate_channels.py b/tests/test_add_coordinate_channels.py index cd33f98fd5..199fe071e3 100644 --- a/tests/test_add_coordinate_channels.py +++ b/tests/test_add_coordinate_channels.py @@ -29,6 +29,7 @@ class TestAddCoordinateChannels(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, input_param, input, expected_shape): result = AddCoordinateChannels(**input_param)(input) diff --git a/tests/test_add_coordinate_channelsd.py b/tests/test_add_coordinate_channelsd.py index f5784928fd..c00240c2d5 100644 --- a/tests/test_add_coordinate_channelsd.py +++ b/tests/test_add_coordinate_channelsd.py @@ -42,6 +42,7 @@ class TestAddCoordinateChannels(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, input_param, input, expected_shape): result = AddCoordinateChannelsd(**input_param)(input)["img"] diff --git a/tests/test_add_extreme_points_channel.py b/tests/test_add_extreme_points_channel.py index 140caa34ba..c453322d6b 100644 --- a/tests/test_add_extreme_points_channel.py +++ b/tests/test_add_extreme_points_channel.py @@ -69,6 +69,7 @@ class TestAddExtremePointsChannel(unittest.TestCase): + @parameterized.expand(TESTS) def test_correct_results(self, input_data, expected): add_extreme_points_channel = AddExtremePointsChannel() diff --git a/tests/test_add_extreme_points_channeld.py b/tests/test_add_extreme_points_channeld.py index 5640e696fc..026f71200a 100644 --- a/tests/test_add_extreme_points_channeld.py +++ b/tests/test_add_extreme_points_channeld.py @@ -64,6 +64,7 @@ class TestAddExtremePointsChanneld(unittest.TestCase): + @parameterized.expand(TESTS) def test_correct_results(self, input_data, expected): add_extreme_points_channel = AddExtremePointsChanneld( diff --git a/tests/test_adjust_contrast.py b/tests/test_adjust_contrast.py index 9fa0247115..2236056558 100644 --- a/tests/test_adjust_contrast.py +++ b/tests/test_adjust_contrast.py @@ -30,6 +30,7 @@ class TestAdjustContrast(NumpyImageTestCase2D): + @parameterized.expand(TESTS) def test_correct_results(self, gamma, invert_image, retain_stats): adjuster = AdjustContrast(gamma=gamma, invert_image=invert_image, retain_stats=retain_stats) diff --git a/tests/test_adjust_contrastd.py b/tests/test_adjust_contrastd.py index 4a671ef7be..38eb001226 100644 --- a/tests/test_adjust_contrastd.py +++ b/tests/test_adjust_contrastd.py @@ -30,6 +30,7 @@ class TestAdjustContrastd(NumpyImageTestCase2D): + @parameterized.expand(TESTS) def test_correct_results(self, gamma, invert_image, retain_stats): adjuster = AdjustContrastd("img", gamma=gamma, invert_image=invert_image, retain_stats=retain_stats) diff --git a/tests/test_adn.py b/tests/test_adn.py index 27e23a08d3..327bf7b20c 100644 --- a/tests/test_adn.py +++ b/tests/test_adn.py @@ -59,6 +59,7 @@ class TestADN2D(TorchImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) def test_adn_2d(self, args): adn = ADN(**args) @@ -73,6 +74,7 @@ def test_no_input(self): class TestADN3D(TorchImageTestCase3D): + @parameterized.expand(TEST_CASES_3D) def test_adn_3d(self, args): adn = ADN(**args) diff --git a/tests/test_adversarial_loss.py b/tests/test_adversarial_loss.py index 77880725ec..f7b9ae7eb0 100644 --- a/tests/test_adversarial_loss.py +++ b/tests/test_adversarial_loss.py @@ -39,6 +39,7 @@ class TestPatchAdversarialLoss(unittest.TestCase): + def get_input(self, shape, is_positive): """ Get tensor for the tests. The tensor is around (-1) or (+1), depending on diff --git a/tests/test_affine.py b/tests/test_affine.py index 9c2f4197a6..a08a22ae6f 100644 --- a/tests/test_affine.py +++ b/tests/test_affine.py @@ -167,6 +167,7 @@ class TestAffine(unittest.TestCase): + @parameterized.expand(TESTS) def test_affine(self, input_param, input_data, expected_val): input_copy = deepcopy(input_data["img"]) @@ -199,6 +200,7 @@ def test_affine(self, input_param, input_data, expected_val): @unittest.skipUnless(optional_import("scipy")[1], "Requires scipy library.") class TestAffineConsistency(unittest.TestCase): + @parameterized.expand([[7], [8], [9]]) def test_affine_resize(self, s): """s""" diff --git a/tests/test_affine_grid.py b/tests/test_affine_grid.py index f3febbe0f3..2d89725bb7 100644 --- a/tests/test_affine_grid.py +++ b/tests/test_affine_grid.py @@ -135,6 +135,7 @@ class TestAffineGrid(unittest.TestCase): + @parameterized.expand(TESTS) def test_affine_grid(self, input_param, input_data, expected_val): g = AffineGrid(**input_param) diff --git a/tests/test_affine_transform.py b/tests/test_affine_transform.py index 39dc609167..11464070e0 100644 --- a/tests/test_affine_transform.py +++ b/tests/test_affine_transform.py @@ -83,6 +83,7 @@ class TestNormTransform(unittest.TestCase): + @parameterized.expand(TEST_NORM_CASES) def test_norm_xform(self, input_shape, align_corners, expected, zero_centered=False): norm = normalize_transform( @@ -107,6 +108,7 @@ def test_norm_xform(self, input_shape, align_corners, expected, zero_centered=Fa class TestToNormAffine(unittest.TestCase): + @parameterized.expand(TEST_TO_NORM_AFFINE_CASES) def test_to_norm_affine(self, affine, src_size, dst_size, align_corners, expected, zero_centered=False): affine = torch.as_tensor(affine, device=torch.device("cpu:0"), dtype=torch.float32) @@ -130,28 +132,18 @@ def test_to_norm_affine_ill(self, affine, src_size, dst_size, align_corners): class TestAffineTransform(unittest.TestCase): - def test_affine_shift(self): - affine = torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]]) - image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]]) - out = AffineTransform(align_corners=False)(image, affine) - out = out.detach().cpu().numpy() - expected = [[[[0, 4, 1, 3], [0, 7, 6, 8], [0, 3, 5, 3]]]] - np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) - def test_affine_shift_1(self): - affine = torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]]) - image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]]) - out = AffineTransform(align_corners=False)(image, affine) - out = out.detach().cpu().numpy() - expected = [[[[0, 0, 0, 0], [0, 4, 1, 3], [0, 7, 6, 8]]]] - np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) - - def test_affine_shift_2(self): - affine = torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]]) + @parameterized.expand( + [ + (torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]]), [[[[0, 4, 1, 3], [0, 7, 6, 8], [0, 3, 5, 3]]]]), + (torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]]), [[[[0, 0, 0, 0], [0, 4, 1, 3], [0, 7, 6, 8]]]]), + (torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]]), [[[[0, 0, 0, 0], [4, 1, 3, 2], [7, 6, 8, 5]]]]), + ] + ) + def test_affine_transforms(self, affine, expected): image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]]) out = AffineTransform(align_corners=False)(image, affine) out = out.detach().cpu().numpy() - expected = [[[[0, 0, 0, 0], [4, 1, 3, 2], [7, 6, 8, 5]]]] np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) def test_zoom(self): diff --git a/tests/test_affined.py b/tests/test_affined.py index a35b35758a..94903ff8c7 100644 --- a/tests/test_affined.py +++ b/tests/test_affined.py @@ -168,6 +168,7 @@ class TestAffined(unittest.TestCase): + @parameterized.expand(TESTS) def test_affine(self, input_param, input_data, expected_val): input_copy = deepcopy(input_data) diff --git a/tests/test_ahnet.py b/tests/test_ahnet.py index 5707cf0452..99a177f395 100644 --- a/tests/test_ahnet.py +++ b/tests/test_ahnet.py @@ -126,6 +126,7 @@ class TestFCN(unittest.TestCase): + @parameterized.expand([TEST_CASE_FCN_1, TEST_CASE_FCN_2, TEST_CASE_FCN_3]) @skip_if_quick def test_fcn_shape(self, input_param, input_shape, expected_shape): @@ -136,6 +137,7 @@ def test_fcn_shape(self, input_param, input_shape, expected_shape): class TestFCNWithPretrain(unittest.TestCase): + @parameterized.expand([TEST_CASE_FCN_WITH_PRETRAIN_1, TEST_CASE_FCN_WITH_PRETRAIN_2]) @skip_if_quick def test_fcn_shape(self, input_param, input_shape, expected_shape): @@ -146,6 +148,7 @@ def test_fcn_shape(self, input_param, input_shape, expected_shape): class TestMCFCN(unittest.TestCase): + @parameterized.expand([TEST_CASE_MCFCN_1, TEST_CASE_MCFCN_2, TEST_CASE_MCFCN_3]) def test_mcfcn_shape(self, input_param, input_shape, expected_shape): net = MCFCN(**input_param).to(device) @@ -155,6 +158,7 @@ def test_mcfcn_shape(self, input_param, input_shape, expected_shape): class TestMCFCNWithPretrain(unittest.TestCase): + @parameterized.expand([TEST_CASE_MCFCN_WITH_PRETRAIN_1, TEST_CASE_MCFCN_WITH_PRETRAIN_2]) def test_mcfcn_shape(self, input_param, input_shape, expected_shape): net = test_pretrained_networks(MCFCN, input_param, device) @@ -164,6 +168,7 @@ def test_mcfcn_shape(self, input_param, input_shape, expected_shape): class TestAHNET(unittest.TestCase): + @parameterized.expand([TEST_CASE_AHNET_2D_1, TEST_CASE_AHNET_2D_2, TEST_CASE_AHNET_2D_3]) def test_ahnet_shape_2d(self, input_param, input_shape, expected_shape): net = AHNet(**input_param).to(device) @@ -192,6 +197,7 @@ def test_script(self): class TestAHNETWithPretrain(unittest.TestCase): + @parameterized.expand( [TEST_CASE_AHNET_3D_WITH_PRETRAIN_1, TEST_CASE_AHNET_3D_WITH_PRETRAIN_2, TEST_CASE_AHNET_3D_WITH_PRETRAIN_3] ) diff --git a/tests/test_anchor_box.py b/tests/test_anchor_box.py index c29296e8ae..301ce78361 100644 --- a/tests/test_anchor_box.py +++ b/tests/test_anchor_box.py @@ -42,6 +42,7 @@ @SkipIfBeforePyTorchVersion((1, 11)) @unittest.skipUnless(has_torchvision, "Requires torchvision") class TestAnchorGenerator(unittest.TestCase): + @parameterized.expand(TEST_CASES_2D) def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes): torch_anchor_utils, _ = optional_import("torchvision.models.detection.anchor_utils") diff --git a/tests/test_apply.py b/tests/test_apply.py index 4784d46413..ca37e945ba 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -39,6 +39,7 @@ def single_2d_transform_cases(): class TestApply(unittest.TestCase): + def _test_apply_impl(self, tensor, pending_transforms, expected_shape): result = apply_pending(tensor, pending_transforms) self.assertListEqual(result[1], pending_transforms) diff --git a/tests/test_apply_filter.py b/tests/test_apply_filter.py index 0de77bfb4d..e8db6da4b9 100644 --- a/tests/test_apply_filter.py +++ b/tests/test_apply_filter.py @@ -20,6 +20,7 @@ class ApplyFilterTestCase(unittest.TestCase): + def test_1d(self): a = torch.tensor([[list(range(10))]], dtype=torch.float) out = apply_filter(a, torch.tensor([-1, 0, 1]), stride=1) diff --git a/tests/test_arraydataset.py b/tests/test_arraydataset.py index a3b78fc6e0..efc014a267 100644 --- a/tests/test_arraydataset.py +++ b/tests/test_arraydataset.py @@ -40,6 +40,7 @@ class TestCompose(Compose): + def __call__(self, input_, lazy): img = self.transforms[0](input_) metadata = img.meta @@ -77,6 +78,7 @@ def __call__(self, input_, lazy): class TestArrayDataset(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, img_transform, label_transform, indices, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=(128, 128, 128)).astype(float), np.eye(4)) diff --git a/tests/test_as_channel_last.py b/tests/test_as_channel_last.py index 8f88fb2928..51e1a5c0fd 100644 --- a/tests/test_as_channel_last.py +++ b/tests/test_as_channel_last.py @@ -27,6 +27,7 @@ class TestAsChannelLast(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, in_type, input_param, expected_shape): test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])) diff --git a/tests/test_as_channel_lastd.py b/tests/test_as_channel_lastd.py index 16086b769c..aa51ab6056 100644 --- a/tests/test_as_channel_lastd.py +++ b/tests/test_as_channel_lastd.py @@ -27,6 +27,7 @@ class TestAsChannelLastd(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, in_type, input_param, expected_shape): test_data = { diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index 2802c7d9ff..bf59752920 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -65,6 +65,7 @@ class TestAsDiscrete(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_value_shape(self, input_param, img, out, expected_shape): result = AsDiscrete(**input_param)(img) diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index ec394fc3af..ed1b3c5b3e 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -68,6 +68,7 @@ class TestAsDiscreted(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_value_shape(self, input_param, test_input, output, expected_shape): result = AsDiscreted(**input_param)(test_input) diff --git a/tests/test_atss_box_matcher.py b/tests/test_atss_box_matcher.py index a614497bc9..6133d4839d 100644 --- a/tests/test_atss_box_matcher.py +++ b/tests/test_atss_box_matcher.py @@ -33,6 +33,7 @@ class TestATSS(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_atss(self, input_param, boxes, anchors, num_anchors_per_level, num_anchors_per_loc, expected_matches): matcher = ATSSMatcher(**input_param, debug=True) diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py index d5c67cee38..83f6cabc5e 100644 --- a/tests/test_attentionunet.py +++ b/tests/test_attentionunet.py @@ -20,6 +20,7 @@ class TestAttentionUnet(unittest.TestCase): + def test_attention_block(self): for dims in [2, 3]: block = att.AttentionBlock(dims, f_int=2, f_g=6, f_l=6) diff --git a/tests/test_auto3dseg.py b/tests/test_auto3dseg.py index 5964ddd6e9..6be33bf6ca 100644 --- a/tests/test_auto3dseg.py +++ b/tests/test_auto3dseg.py @@ -165,6 +165,7 @@ def __call__(self, data): class TestDataAnalyzer(unittest.TestCase): + def setUp(self): self.test_dir = tempfile.TemporaryDirectory() work_dir = self.test_dir.name @@ -366,7 +367,6 @@ def test_filename_case_analyzer(self): for batch_data in self.dataset: d = transform(batch_data[0]) assert DataStatsKeys.BY_CASE_IMAGE_PATH in d - assert DataStatsKeys.BY_CASE_IMAGE_PATH in d def test_filename_case_analyzer_image_only(self): analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH) diff --git a/tests/test_auto3dseg_bundlegen.py b/tests/test_auto3dseg_bundlegen.py index 1d2d6611bb..e7bf6820bc 100644 --- a/tests/test_auto3dseg_bundlegen.py +++ b/tests/test_auto3dseg_bundlegen.py @@ -107,6 +107,7 @@ def run_auto3dseg_before_bundlegen(test_path, work_dir): @SkipIfBeforePyTorchVersion((1, 11, 1)) @skip_if_quick class TestBundleGen(unittest.TestCase): + def setUp(self) -> None: set_determinism(0) self.test_dir = tempfile.TemporaryDirectory() diff --git a/tests/test_auto3dseg_ensemble.py b/tests/test_auto3dseg_ensemble.py index 367f66581c..7ac553cc0c 100644 --- a/tests/test_auto3dseg_ensemble.py +++ b/tests/test_auto3dseg_ensemble.py @@ -112,6 +112,7 @@ def create_sim_data(dataroot, sim_datalist, sim_dim, **kwargs): @SkipIfBeforePyTorchVersion((1, 11, 1)) @unittest.skipIf(not has_tb, "no tensorboard summary writer") class TestEnsembleBuilder(unittest.TestCase): + def setUp(self) -> None: set_determinism(0) self.test_dir = tempfile.TemporaryDirectory() diff --git a/tests/test_auto3dseg_hpo.py b/tests/test_auto3dseg_hpo.py index 0441116dc9..53d09defa0 100644 --- a/tests/test_auto3dseg_hpo.py +++ b/tests/test_auto3dseg_hpo.py @@ -79,6 +79,7 @@ def skip_if_no_optuna(obj): @SkipIfBeforePyTorchVersion((1, 11, 1)) @unittest.skipIf(not has_tb, "no tensorboard summary writer") class TestHPO(unittest.TestCase): + def setUp(self) -> None: self.test_dir = tempfile.TemporaryDirectory() test_path = self.test_dir.name @@ -154,6 +155,7 @@ def test_run_optuna(self) -> None: algo = algo_dict[AlgoKeys.ALGO] class OptunaGenLearningRate(OptunaGen): + def get_hyperparameters(self): return {"learning_rate": self.trial.suggest_float("learning_rate", 0.00001, 0.1)} @@ -179,7 +181,7 @@ def test_get_history(self) -> None: NNIGen().run_algo(obj_filename, self.work_dir) history = import_bundle_algo_history(self.work_dir, only_trained=True) - assert len(history) == 3 + assert len(history) == 1 def tearDown(self) -> None: self.test_dir.cleanup() diff --git a/tests/test_autoencoder.py b/tests/test_autoencoder.py index 485049c2d1..6408f6a6d0 100644 --- a/tests/test_autoencoder.py +++ b/tests/test_autoencoder.py @@ -74,6 +74,7 @@ class TestAutoEncoder(unittest.TestCase): + @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): net = AutoEncoder(**input_param).to(device) diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py index 448f1e8e9a..3cc671a1d0 100644 --- a/tests/test_autoencoderkl.py +++ b/tests/test_autoencoderkl.py @@ -18,11 +18,14 @@ from monai.networks import eval_mode from monai.networks.nets import AutoencoderKL +from monai.utils import optional_import from tests.utils import SkipIfBeforePyTorchVersion +tqdm, has_tqdm = optional_import("tqdm", name="tqdm") +einops, has_einops = optional_import("einops") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -CASES = [ +CASES_NO_ATTENTION = [ [ { "spatial_dims": 2, @@ -33,11 +36,33 @@ "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, }, (1, 1, 16, 16), (1, 1, 16, 16), (1, 4, 4, 4), ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + }, + (1, 1, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], +] + +CASES_ATTENTION = [ [ { "spatial_dims": 2, @@ -46,7 +71,7 @@ "channels": (4, 4, 4), "latent_channels": 4, "attention_levels": (False, False, False), - "num_res_blocks": (1, 1, 2), + "num_res_blocks": 1, "norm_num_groups": 4, }, (1, 1, 16, 16), @@ -61,7 +86,7 @@ "channels": (4, 4, 4), "latent_channels": 4, "attention_levels": (False, False, False), - "num_res_blocks": 1, + "num_res_blocks": (1, 1, 2), "norm_num_groups": 4, }, (1, 1, 16, 16), @@ -75,7 +100,7 @@ "out_channels": 1, "channels": (4, 4, 4), "latent_channels": 4, - "attention_levels": (False, False, True), + "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, }, @@ -106,11 +131,9 @@ "out_channels": 1, "channels": (4, 4, 4), "latent_channels": 4, - "attention_levels": (False, False, False), + "attention_levels": (False, False, True), "num_res_blocks": 1, "norm_num_groups": 4, - "with_encoder_nonlocal_attn": False, - "with_decoder_nonlocal_attn": False, }, (1, 1, 16, 16), (1, 1, 16, 16), @@ -133,6 +156,11 @@ ], ] +if has_einops: + CASES = CASES_NO_ATTENTION + CASES_ATTENTION +else: + CASES = CASES_NO_ATTENTION + class TestAutoEncoderKL(unittest.TestCase): @parameterized.expand(CASES) diff --git a/tests/test_avg_merger.py b/tests/test_avg_merger.py index adef2a759a..7995d63271 100644 --- a/tests/test_avg_merger.py +++ b/tests/test_avg_merger.py @@ -137,6 +137,7 @@ class AvgMergerTests(unittest.TestCase): + @parameterized.expand( [ TEST_CASE_0_DEFAULT_DTYPE, diff --git a/tests/test_barlow_twins_loss.py b/tests/test_barlow_twins_loss.py new file mode 100644 index 0000000000..81f4032e0c --- /dev/null +++ b/tests/test_barlow_twins_loss.py @@ -0,0 +1,109 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses import BarlowTwinsLoss + +TEST_CASES = [ + [ # shape: (2, 4), (2, 4) + {"lambd": 5e-3}, + { + "input": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]), + "target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]), + }, + 4.0, + ], + [ # shape: (2, 4), (2, 4) + {"lambd": 5e-3}, + { + "input": torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]), + "target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]), + }, + 4.0, + ], + [ # shape: (2, 4), (2, 4) + {"lambd": 5e-3}, + { + "input": torch.tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]]), + "target": torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 1.0]]), + }, + 5.2562, + ], + [ # shape: (2, 4), (2, 4) + {"lambd": 5e-4}, + { + "input": torch.tensor([[2.0, 3.0, 1.0, 2.0], [0.0, 1.0, 2.0, 5.0]]), + "target": torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]), + }, + 5.0015, + ], + [ # shape: (4, 4), (4, 4) + {"lambd": 5e-3}, + { + "input": torch.tensor( + [[1.0, 2.0, 1.0, 1.0], [3.0, 1.0, 1.0, 2.0], [1.0, 1.0, 1.0, 1.0], [2.0, 1.0, 1.0, 0.0]] + ), + "target": torch.tensor( + [ + [0.0, 1.0, -1.0, 0.0], + [1 / 3, 0.0, -2 / 3, 1 / 3], + [-2 / 3, -1.0, 7 / 3, 1 / 3], + [1 / 3, 0.0, 1 / 3, -2 / 3], + ] + ), + }, + 1.4736, + ], +] + + +class TestBarlowTwinsLoss(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_result(self, input_param, input_data, expected_val): + barlowtwinsloss = BarlowTwinsLoss(**input_param) + result = barlowtwinsloss(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + + def test_ill_shape(self): + loss = BarlowTwinsLoss(lambd=5e-3) + with self.assertRaises(ValueError): + loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + + def test_ill_batch_size(self): + loss = BarlowTwinsLoss(lambd=5e-3) + with self.assertRaises(ValueError): + loss(torch.ones((1, 2)), torch.ones((1, 2))) + + def test_with_cuda(self): + loss = BarlowTwinsLoss(lambd=5e-3) + i = torch.ones((2, 10)) + j = torch.ones((2, 10)) + if torch.cuda.is_available(): + i = i.cuda() + j = j.cuda() + output = loss(i, j) + np.testing.assert_allclose(output.detach().cpu().numpy(), 10.0, atol=1e-4, rtol=1e-4) + + def check_warning_raised(self): + with self.assertWarns(Warning): + BarlowTwinsLoss(lambd=5e-3, batch_size=1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_basic_unet.py b/tests/test_basic_unet.py index 23e19dd536..770750851f 100644 --- a/tests/test_basic_unet.py +++ b/tests/test_basic_unet.py @@ -83,6 +83,7 @@ class TestBasicUNET(unittest.TestCase): + @parameterized.expand(CASES_1D + CASES_2D + CASES_3D) def test_shape(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/tests/test_basic_unetplusplus.py b/tests/test_basic_unetplusplus.py index 19ed5977fd..6438b5e0d4 100644 --- a/tests/test_basic_unetplusplus.py +++ b/tests/test_basic_unetplusplus.py @@ -83,6 +83,7 @@ class TestBasicUNETPlusPlus(unittest.TestCase): + @parameterized.expand(CASES_1D + CASES_2D + CASES_3D) def test_shape(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/tests/test_bending_energy.py b/tests/test_bending_energy.py index f29d4f256b..2e8ab32dbd 100644 --- a/tests/test_bending_energy.py +++ b/tests/test_bending_energy.py @@ -50,6 +50,7 @@ class TestBendingEnergy(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): result = BendingEnergyLoss(**input_param).forward(**input_data) diff --git a/tests/test_bilateral_approx_cpu.py b/tests/test_bilateral_approx_cpu.py index da30d5d7de..e8a55e1f76 100644 --- a/tests/test_bilateral_approx_cpu.py +++ b/tests/test_bilateral_approx_cpu.py @@ -365,6 +365,7 @@ @skip_if_no_cpp_extension class BilateralFilterTestCaseCpuApprox(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_cpu_approx(self, test_case_description, sigmas, input, expected): # Params to determine the implementation to test diff --git a/tests/test_bilateral_approx_cuda.py b/tests/test_bilateral_approx_cuda.py index b9be7d9ccf..4ad15d9646 100644 --- a/tests/test_bilateral_approx_cuda.py +++ b/tests/test_bilateral_approx_cuda.py @@ -366,6 +366,7 @@ @skip_if_no_cuda @skip_if_no_cpp_extension class BilateralFilterTestCaseCudaApprox(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_cuda_approx(self, test_case_description, sigmas, input, expected): # Skip this test diff --git a/tests/test_bilateral_precise.py b/tests/test_bilateral_precise.py index 1a68dc8b4e..e13ede5bfd 100644 --- a/tests/test_bilateral_precise.py +++ b/tests/test_bilateral_precise.py @@ -366,6 +366,7 @@ @skip_if_no_cpp_extension @skip_if_quick class BilateralFilterTestCaseCpuPrecise(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_cpu_precise(self, test_case_description, sigmas, input, expected): # Params to determine the implementation to test @@ -399,6 +400,7 @@ def test_cpu_precise_backwards(self, test_case_description, sigmas, input, expec @skip_if_no_cuda @skip_if_no_cpp_extension class BilateralFilterTestCaseCudaPrecise(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_cuda_precise(self, test_case_description, sigmas, input, expected): # Skip this test diff --git a/tests/test_blend_images.py b/tests/test_blend_images.py index 9814a5a3f8..700ae1fe58 100644 --- a/tests/test_blend_images.py +++ b/tests/test_blend_images.py @@ -44,6 +44,7 @@ def get_alpha(img): @skipUnless(has_matplotlib, "Matplotlib required") class TestBlendImages(unittest.TestCase): + @parameterized.expand(TESTS) def test_blend(self, image, label, alpha): blended = blend_images(image, label, alpha) diff --git a/tests/test_bounding_rect.py b/tests/test_bounding_rect.py index b9c232e2d2..b879fa6093 100644 --- a/tests/test_bounding_rect.py +++ b/tests/test_bounding_rect.py @@ -28,6 +28,7 @@ class TestBoundingRect(unittest.TestCase): + def setUp(self): monai.utils.set_determinism(1) diff --git a/tests/test_bounding_rectd.py b/tests/test_bounding_rectd.py index 248a0a8e47..96435036b1 100644 --- a/tests/test_bounding_rectd.py +++ b/tests/test_bounding_rectd.py @@ -28,6 +28,7 @@ class TestBoundingRectD(unittest.TestCase): + def setUp(self): monai.utils.set_determinism(1) diff --git a/tests/test_box_coder.py b/tests/test_box_coder.py index 5835341139..75ff650d6c 100644 --- a/tests/test_box_coder.py +++ b/tests/test_box_coder.py @@ -21,6 +21,7 @@ class TestBoxTransform(unittest.TestCase): + def test_value(self): box_coder = BoxCoder(weights=[1, 1, 1, 1, 1, 1]) test_dtype = [torch.float32, torch.float16] diff --git a/tests/test_box_transform.py b/tests/test_box_transform.py index e114f8869f..e99f95fa32 100644 --- a/tests/test_box_transform.py +++ b/tests/test_box_transform.py @@ -79,6 +79,7 @@ class TestBoxTransform(unittest.TestCase): + @parameterized.expand(TESTS_2D_mask) def test_value_2d_mask(self, mask, expected_box_label): box_label = convert_mask_to_box(mask) diff --git a/tests/test_box_utils.py b/tests/test_box_utils.py index c4fefb5a98..3c05efe0d0 100644 --- a/tests/test_box_utils.py +++ b/tests/test_box_utils.py @@ -140,6 +140,7 @@ class TestCreateBoxList(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, input_data, mode2, expected_box, expected_area): expected_box = convert_data_type(expected_box, dtype=np.float32)[0] diff --git a/tests/test_bundle_ckpt_export.py b/tests/test_bundle_ckpt_export.py index d9b3bedab2..8f376a06d5 100644 --- a/tests/test_bundle_ckpt_export.py +++ b/tests/test_bundle_ckpt_export.py @@ -32,6 +32,7 @@ @skip_if_windows class TestCKPTExport(unittest.TestCase): + def setUp(self): self.device = os.environ.get("CUDA_VISIBLE_DEVICES") if not self.device: diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index fa96c6f28d..89fbe5e8b2 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -93,6 +93,7 @@ class TestDownload(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) @skip_if_quick def test_github_download_bundle(self, bundle_name, version): @@ -192,6 +193,7 @@ def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, ve @skip_if_no_cuda class TestLoad(unittest.TestCase): + @parameterized.expand([TEST_CASE_7]) @skip_if_quick def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file): @@ -336,6 +338,7 @@ def test_load_ts_module(self, bundle_files, bundle_name, version, repo, device, class TestDownloadLargefiles(unittest.TestCase): + @parameterized.expand([TEST_CASE_10]) @skip_if_quick def test_url_download_large_files(self, bundle_files, bundle_name, url, hash_val): diff --git a/tests/test_bundle_get_data.py b/tests/test_bundle_get_data.py index 88bfed758a..605b3945bb 100644 --- a/tests/test_bundle_get_data.py +++ b/tests/test_bundle_get_data.py @@ -45,6 +45,7 @@ @skip_if_windows @SkipIfNoModule("requests") class TestGetBundleData(unittest.TestCase): + @parameterized.expand([TEST_CASE_3, TEST_CASE_4]) @skip_if_quick def test_get_all_bundles_list(self, params): diff --git a/tests/test_bundle_init_bundle.py b/tests/test_bundle_init_bundle.py index 08f921da01..eb831093d5 100644 --- a/tests/test_bundle_init_bundle.py +++ b/tests/test_bundle_init_bundle.py @@ -23,6 +23,7 @@ @skip_if_windows class TestBundleInit(unittest.TestCase): + def test_bundle(self): with tempfile.TemporaryDirectory() as tempdir: net = UNet(2, 1, 1, [4, 8], [2]) diff --git a/tests/test_bundle_onnx_export.py b/tests/test_bundle_onnx_export.py index ffd5fa636d..ee22d7caef 100644 --- a/tests/test_bundle_onnx_export.py +++ b/tests/test_bundle_onnx_export.py @@ -29,6 +29,7 @@ @SkipIfNoModule("onnx") @SkipIfBeforePyTorchVersion((1, 10)) class TestONNXExport(unittest.TestCase): + def setUp(self): self.device = os.environ.get("CUDA_VISIBLE_DEVICES") if not self.device: diff --git a/tests/test_bundle_push_to_hf_hub.py b/tests/test_bundle_push_to_hf_hub.py index 375c5d81e8..39368c6f40 100644 --- a/tests/test_bundle_push_to_hf_hub.py +++ b/tests/test_bundle_push_to_hf_hub.py @@ -28,6 +28,7 @@ class TestPushToHuggingFaceHub(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) @skip_if_quick @skipUnless(has_huggingface_hub, "Requires `huggingface_hub` package.") diff --git a/tests/test_bundle_trt_export.py b/tests/test_bundle_trt_export.py index 72743f5fcb..47034852ef 100644 --- a/tests/test_bundle_trt_export.py +++ b/tests/test_bundle_trt_export.py @@ -48,6 +48,7 @@ @skip_if_no_cuda @skip_if_quick class TestTRTExport(unittest.TestCase): + def setUp(self): self.device = os.environ.get("CUDA_VISIBLE_DEVICES") if not self.device: diff --git a/tests/test_bundle_utils.py b/tests/test_bundle_utils.py index 181c08475c..47c534f3b6 100644 --- a/tests/test_bundle_utils.py +++ b/tests/test_bundle_utils.py @@ -51,6 +51,7 @@ @skip_if_windows class TestLoadBundleConfig(unittest.TestCase): + def setUp(self): self.bundle_dir = tempfile.TemporaryDirectory() self.dir_name = os.path.join(self.bundle_dir.name, "TestBundle") @@ -134,6 +135,7 @@ def test_load_config_ts(self): class TestPPrintEdges(unittest.TestCase): + def test_str(self): self.assertEqual(pprint_edges("", 0), "''") self.assertEqual(pprint_edges({"a": 1, "b": 2}, 0), "{'a': 1, 'b': 2}") diff --git a/tests/test_bundle_verify_metadata.py b/tests/test_bundle_verify_metadata.py index 0701e905b9..f6c2192621 100644 --- a/tests/test_bundle_verify_metadata.py +++ b/tests/test_bundle_verify_metadata.py @@ -28,6 +28,7 @@ @skip_if_windows class TestVerifyMetaData(unittest.TestCase): + def setUp(self): self.config = testing_data_config("configs", "test_meta_file") download_url_or_skip_test( diff --git a/tests/test_bundle_verify_net.py b/tests/test_bundle_verify_net.py index 6f516fdd48..f55fdd597b 100644 --- a/tests/test_bundle_verify_net.py +++ b/tests/test_bundle_verify_net.py @@ -28,6 +28,7 @@ @skip_if_windows class TestVerifyNetwork(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) def test_verify(self, meta_file, config_file): with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_bundle_workflow.py b/tests/test_bundle_workflow.py index 4291eedf3f..9a276b577f 100644 --- a/tests/test_bundle_workflow.py +++ b/tests/test_bundle_workflow.py @@ -35,8 +35,11 @@ TEST_CASE_3 = [os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json")] +TEST_CASE_NON_CONFIG_WRONG_LOG = [None, "logging.conf", "Cannot find the logging config file: logging.conf."] + class TestBundleWorkflow(unittest.TestCase): + def setUp(self): self.data_dir = tempfile.mkdtemp() self.expected_shape = (128, 128, 128) @@ -102,6 +105,16 @@ def test_inference_config(self, config_file): ) self._test_inferer(inferer) + # test property path + inferer = ConfigWorkflow( + config_file=config_file, + properties_path=os.path.join(os.path.dirname(__file__), "testing_data", "fl_infer_properties.json"), + logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"), + **override, + ) + self._test_inferer(inferer) + self.assertEqual(inferer.workflow_type, None) + @parameterized.expand([TEST_CASE_3]) def test_train_config(self, config_file): # test standard MONAI model-zoo config workflow @@ -143,8 +156,14 @@ def test_train_config(self, config_file): def test_non_config(self): # test user defined python style workflow inferer = NonConfigWorkflow(self.filename, self.data_dir) + self.assertEqual(inferer.meta_file, None) self._test_inferer(inferer) + @parameterized.expand([TEST_CASE_NON_CONFIG_WRONG_LOG]) + def test_non_config_wrong_log_cases(self, meta_file, logging_file, expected_error): + with self.assertRaisesRegex(FileNotFoundError, expected_error): + NonConfigWorkflow(self.filename, self.data_dir, meta_file, logging_file) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index dcae5fdce1..dbb1b8f8f1 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -39,6 +39,7 @@ class TestCacheDataset(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, transform, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) diff --git a/tests/test_cachedataset_parallel.py b/tests/test_cachedataset_parallel.py index c3fc2cc362..6a01a82512 100644 --- a/tests/test_cachedataset_parallel.py +++ b/tests/test_cachedataset_parallel.py @@ -30,6 +30,7 @@ class TestCacheDatasetParallel(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, num_workers, dataset_size, transform): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[8, 8, 8]).astype(float), np.eye(4)) diff --git a/tests/test_cachedataset_persistent_workers.py b/tests/test_cachedataset_persistent_workers.py index e60862238d..78092906c6 100644 --- a/tests/test_cachedataset_persistent_workers.py +++ b/tests/test_cachedataset_persistent_workers.py @@ -18,6 +18,7 @@ class TestTransformsWCacheDatasetAndPersistentWorkers(unittest.TestCase): + def test_duplicate_transforms(self): data = [{"img": create_test_image_2d(128, 128, num_seg_classes=1, channel_dim=0)[0]} for _ in range(2)] diff --git a/tests/test_cachentransdataset.py b/tests/test_cachentransdataset.py index d50fe4f8dd..90e86c2eb0 100644 --- a/tests/test_cachentransdataset.py +++ b/tests/test_cachentransdataset.py @@ -34,6 +34,7 @@ class TestCacheNTransDataset(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) def test_n_trans(self, transform, expected_shape): data_array = np.random.randint(0, 2, size=[128, 128, 128]).astype(float) diff --git a/tests/test_call_dist.py b/tests/test_call_dist.py index 0621824b65..503cb5e792 100644 --- a/tests/test_call_dist.py +++ b/tests/test_call_dist.py @@ -17,6 +17,7 @@ class DistributedCallTest(DistTestCase): + def test_constructor(self): with self.assertRaises(ValueError): DistCall(nnodes=1, nproc_per_node=0) diff --git a/tests/test_cast_to_type.py b/tests/test_cast_to_type.py index 6dd994120c..035260804e 100644 --- a/tests/test_cast_to_type.py +++ b/tests/test_cast_to_type.py @@ -37,6 +37,7 @@ class TestCastToType(unittest.TestCase): + @parameterized.expand(TESTS) def test_type(self, out_dtype, input_data, expected_type): result = CastToType(dtype=out_dtype)(input_data) diff --git a/tests/test_cast_to_typed.py b/tests/test_cast_to_typed.py index 687deeda1d..81e17117a9 100644 --- a/tests/test_cast_to_typed.py +++ b/tests/test_cast_to_typed.py @@ -53,6 +53,7 @@ class TestCastToTyped(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_type(self, input_param, input_data, expected_type): result = CastToTyped(**input_param)(input_data) diff --git a/tests/test_channel_pad.py b/tests/test_channel_pad.py index 2d8c57fd68..77dd172378 100644 --- a/tests/test_channel_pad.py +++ b/tests/test_channel_pad.py @@ -34,6 +34,7 @@ class TestChannelPad(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) def test_shape(self, input_param, input_shape, expected_shape): net = ChannelPad(**input_param) diff --git a/tests/test_check_hash.py b/tests/test_check_hash.py index bb3d0ff12e..263c18703c 100644 --- a/tests/test_check_hash.py +++ b/tests/test_check_hash.py @@ -32,6 +32,7 @@ class TestCheckMD5(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_result(self, md5_value, t, expected_result): test_image = np.ones((5, 5, 3)) diff --git a/tests/test_check_missing_files.py b/tests/test_check_missing_files.py index efbe5a95fb..2b5c17a1ec 100644 --- a/tests/test_check_missing_files.py +++ b/tests/test_check_missing_files.py @@ -23,6 +23,7 @@ class TestCheckMissingFiles(unittest.TestCase): + def test_content(self): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_classes_to_indices.py b/tests/test_classes_to_indices.py index e7dd7abfe5..a7377dac16 100644 --- a/tests/test_classes_to_indices.py +++ b/tests/test_classes_to_indices.py @@ -82,6 +82,7 @@ class TestClassesToIndices(unittest.TestCase): + @parameterized.expand(TESTS_CASES) def test_value(self, input_args, label, image, expected_indices): indices = ClassesToIndices(**input_args)(label, image) diff --git a/tests/test_classes_to_indicesd.py b/tests/test_classes_to_indicesd.py index 7a34cc06b4..dead1ae753 100644 --- a/tests/test_classes_to_indicesd.py +++ b/tests/test_classes_to_indicesd.py @@ -97,6 +97,7 @@ class TestClassesToIndicesd(unittest.TestCase): + @parameterized.expand(TESTS_CASES) def test_value(self, input_args, input_data, expected_indices): result = ClassesToIndicesd(**input_args)(input_data) diff --git a/tests/test_cldice_loss.py b/tests/test_cldice_loss.py index 071bd20d6c..14d3575e3b 100644 --- a/tests/test_cldice_loss.py +++ b/tests/test_cldice_loss.py @@ -23,6 +23,7 @@ class TestclDiceLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_result(self, y_pred_data, expected_val): loss = SoftclDiceLoss() diff --git a/tests/test_clip_intensity_percentiles.py b/tests/test_clip_intensity_percentiles.py new file mode 100644 index 0000000000..af157446f6 --- /dev/null +++ b/tests/test_clip_intensity_percentiles.py @@ -0,0 +1,185 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms import ClipIntensityPercentiles +from monai.transforms.utils import soft_clip +from monai.transforms.utils_pytorch_numpy_unification import clip, percentile +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose + + +class TestClipIntensityPercentiles2D(NumpyImageTestCase2D): + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_hard_clipping_two_sided(self, p): + hard_clipper = ClipIntensityPercentiles(upper=95, lower=5) + im = p(self.imt) + result = hard_clipper(im) + lower, upper = percentile(im, (5, 95)) + expected = clip(im, lower, upper) + assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_hard_clipping_one_sided_high(self, p): + hard_clipper = ClipIntensityPercentiles(upper=95, lower=None) + im = p(self.imt) + result = hard_clipper(im) + lower, upper = percentile(im, (0, 95)) + expected = clip(im, lower, upper) + assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_hard_clipping_one_sided_low(self, p): + hard_clipper = ClipIntensityPercentiles(upper=None, lower=5) + im = p(self.imt) + result = hard_clipper(im) + lower, upper = percentile(im, (5, 100)) + expected = clip(im, lower, upper) + assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_soft_clipping_two_sided(self, p): + soft_clipper = ClipIntensityPercentiles(upper=95, lower=5, sharpness_factor=1.0) + im = p(self.imt) + result = soft_clipper(im) + lower, upper = percentile(im, (5, 95)) + expected = soft_clip(im, sharpness_factor=1.0, minv=lower, maxv=upper, dtype=torch.float32) + # the rtol is set to 1e-6 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result, p(expected), type_test="tensor", rtol=1e-6, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_soft_clipping_one_sided_high(self, p): + soft_clipper = ClipIntensityPercentiles(upper=95, lower=None, sharpness_factor=1.0) + im = p(self.imt) + result = soft_clipper(im) + upper = percentile(im, 95) + expected = soft_clip(im, sharpness_factor=1.0, minv=None, maxv=upper, dtype=torch.float32) + # the rtol is set to 5e-5 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result, p(expected), type_test="tensor", rtol=5e-5, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_soft_clipping_one_sided_low(self, p): + soft_clipper = ClipIntensityPercentiles(upper=None, lower=5, sharpness_factor=1.0) + im = p(self.imt) + result = soft_clipper(im) + lower = percentile(im, 5) + expected = soft_clip(im, sharpness_factor=1.0, minv=lower, maxv=None, dtype=torch.float32) + # the rtol is set to 1e-6 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result, p(expected), type_test="tensor", rtol=1e-6, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_channel_wise(self, p): + clipper = ClipIntensityPercentiles(upper=95, lower=5, channel_wise=True) + im = p(self.imt) + result = clipper(im) + for i, c in enumerate(im): + lower, upper = percentile(c, (5, 95)) + expected = clip(c, lower, upper) + assert_allclose(result[i], p(expected), type_test="tensor", rtol=1e-4, atol=0) + + def test_ill_sharpness_factor(self): + with self.assertRaises(ValueError): + ClipIntensityPercentiles(upper=95, lower=5, sharpness_factor=0.0) + + def test_ill_lower_percentile(self): + with self.assertRaises(ValueError): + ClipIntensityPercentiles(upper=None, lower=-1) + + def test_ill_upper_percentile(self): + with self.assertRaises(ValueError): + ClipIntensityPercentiles(upper=101, lower=None) + + def test_ill_percentiles(self): + with self.assertRaises(ValueError): + ClipIntensityPercentiles(upper=95, lower=96) + + def test_ill_both_none(self): + with self.assertRaises(ValueError): + ClipIntensityPercentiles(upper=None, lower=None) + + +class TestClipIntensityPercentiles3D(NumpyImageTestCase3D): + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_hard_clipping_two_sided(self, p): + hard_clipper = ClipIntensityPercentiles(upper=95, lower=5) + im = p(self.imt) + result = hard_clipper(im) + lower, upper = percentile(im, (5, 95)) + expected = clip(im, lower, upper) + assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_hard_clipping_one_sided_high(self, p): + hard_clipper = ClipIntensityPercentiles(upper=95, lower=None) + im = p(self.imt) + result = hard_clipper(im) + lower, upper = percentile(im, (0, 95)) + expected = clip(im, lower, upper) + assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_hard_clipping_one_sided_low(self, p): + hard_clipper = ClipIntensityPercentiles(upper=None, lower=5) + im = p(self.imt) + result = hard_clipper(im) + lower, upper = percentile(im, (5, 100)) + expected = clip(im, lower, upper) + assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_soft_clipping_two_sided(self, p): + soft_clipper = ClipIntensityPercentiles(upper=95, lower=5, sharpness_factor=1.0) + im = p(self.imt) + result = soft_clipper(im) + lower, upper = percentile(im, (5, 95)) + expected = soft_clip(im, sharpness_factor=1.0, minv=lower, maxv=upper, dtype=torch.float32) + # the rtol is set to 1e-6 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result, p(expected), type_test="tensor", rtol=1e-6, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_soft_clipping_one_sided_high(self, p): + soft_clipper = ClipIntensityPercentiles(upper=95, lower=None, sharpness_factor=1.0) + im = p(self.imt) + result = soft_clipper(im) + upper = percentile(im, 95) + expected = soft_clip(im, sharpness_factor=1.0, minv=None, maxv=upper, dtype=torch.float32) + # the rtol is set to 5e-5 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result, p(expected), type_test="tensor", rtol=5e-5, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_soft_clipping_one_sided_low(self, p): + soft_clipper = ClipIntensityPercentiles(upper=None, lower=5, sharpness_factor=1.0) + im = p(self.imt) + result = soft_clipper(im) + lower = percentile(im, 5) + expected = soft_clip(im, sharpness_factor=1.0, minv=lower, maxv=None, dtype=torch.float32) + # the rtol is set to 1e-6 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result, p(expected), type_test="tensor", rtol=1e-6, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_channel_wise(self, p): + clipper = ClipIntensityPercentiles(upper=95, lower=5, channel_wise=True) + im = p(self.imt) + result = clipper(im) + for i, c in enumerate(im): + lower, upper = percentile(c, (5, 95)) + expected = clip(c, lower, upper) + assert_allclose(result[i], p(expected), type_test="tensor", rtol=1e-4, atol=0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_clip_intensity_percentilesd.py b/tests/test_clip_intensity_percentilesd.py new file mode 100644 index 0000000000..fa727b6adb --- /dev/null +++ b/tests/test_clip_intensity_percentilesd.py @@ -0,0 +1,205 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms import ClipIntensityPercentilesd +from monai.transforms.utils import soft_clip +from monai.transforms.utils_pytorch_numpy_unification import clip, percentile +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose + + +class TestClipIntensityPercentilesd2D(NumpyImageTestCase2D): + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_hard_clipping_two_sided(self, p): + key = "img" + hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5) + im = p(self.imt) + result = hard_clipper({key: im}) + lower, upper = percentile(im, (5, 95)) + expected = clip(im, lower, upper) + assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_hard_clipping_one_sided_high(self, p): + key = "img" + hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=None) + im = p(self.imt) + result = hard_clipper({key: im}) + lower, upper = percentile(im, (0, 95)) + expected = clip(im, lower, upper) + assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_hard_clipping_one_sided_low(self, p): + key = "img" + hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=None, lower=5) + im = p(self.imt) + result = hard_clipper({key: im}) + lower, upper = percentile(im, (5, 100)) + expected = clip(im, lower, upper) + assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_soft_clipping_two_sided(self, p): + key = "img" + soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, sharpness_factor=1.0) + im = p(self.imt) + result = soft_clipper({key: im}) + lower, upper = percentile(im, (5, 95)) + expected = soft_clip(im, sharpness_factor=1.0, minv=lower, maxv=upper, dtype=torch.float32) + # the rtol is set to 1e-6 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-6, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_soft_clipping_one_sided_high(self, p): + key = "img" + soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=None, sharpness_factor=1.0) + im = p(self.imt) + result = soft_clipper({key: im}) + upper = percentile(im, 95) + expected = soft_clip(im, sharpness_factor=1.0, minv=None, maxv=upper, dtype=torch.float32) + # the rtol is set to 5e-5 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result[key], p(expected), type_test="tensor", rtol=5e-5, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_soft_clipping_one_sided_low(self, p): + key = "img" + soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=None, lower=5, sharpness_factor=1.0) + im = p(self.imt) + result = soft_clipper({key: im}) + lower = percentile(im, 5) + expected = soft_clip(im, sharpness_factor=1.0, minv=lower, maxv=None, dtype=torch.float32) + # the rtol is set to 1e-6 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-6, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_channel_wise(self, p): + key = "img" + clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, channel_wise=True) + im = p(self.imt) + result = clipper({key: im}) + for i, c in enumerate(im): + lower, upper = percentile(c, (5, 95)) + expected = clip(c, lower, upper) + assert_allclose(result[key][i], p(expected), type_test="tensor", rtol=1e-4, atol=0) + + def test_ill_sharpness_factor(self): + key = "img" + with self.assertRaises(ValueError): + ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, sharpness_factor=0.0) + + def test_ill_lower_percentile(self): + key = "img" + with self.assertRaises(ValueError): + ClipIntensityPercentilesd(keys=[key], upper=None, lower=-1) + + def test_ill_upper_percentile(self): + key = "img" + with self.assertRaises(ValueError): + ClipIntensityPercentilesd(keys=[key], upper=101, lower=None) + + def test_ill_percentiles(self): + key = "img" + with self.assertRaises(ValueError): + ClipIntensityPercentilesd(keys=[key], upper=95, lower=96) + + def test_ill_both_none(self): + key = "img" + with self.assertRaises(ValueError): + ClipIntensityPercentilesd(keys=[key], upper=None, lower=None) + + +class TestClipIntensityPercentilesd3D(NumpyImageTestCase3D): + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_hard_clipping_two_sided(self, p): + key = "img" + hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5) + im = p(self.imt) + result = hard_clipper({key: im}) + lower, upper = percentile(im, (5, 95)) + expected = clip(im, lower, upper) + assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_hard_clipping_one_sided_high(self, p): + key = "img" + hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=None) + im = p(self.imt) + result = hard_clipper({key: im}) + lower, upper = percentile(im, (0, 95)) + expected = clip(im, lower, upper) + assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_hard_clipping_one_sided_low(self, p): + key = "img" + hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=None, lower=5) + im = p(self.imt) + result = hard_clipper({key: im}) + lower, upper = percentile(im, (5, 100)) + expected = clip(im, lower, upper) + assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_soft_clipping_two_sided(self, p): + key = "img" + soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, sharpness_factor=1.0) + im = p(self.imt) + result = soft_clipper({key: im}) + lower, upper = percentile(im, (5, 95)) + expected = soft_clip(im, sharpness_factor=1.0, minv=lower, maxv=upper, dtype=torch.float32) + # the rtol is set to 1e-6 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-6, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_soft_clipping_one_sided_high(self, p): + key = "img" + soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=None, sharpness_factor=1.0) + im = p(self.imt) + result = soft_clipper({key: im}) + upper = percentile(im, 95) + expected = soft_clip(im, sharpness_factor=1.0, minv=None, maxv=upper, dtype=torch.float32) + # the rtol is set to 5e-5 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result[key], p(expected), type_test="tensor", rtol=5e-5, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_soft_clipping_one_sided_low(self, p): + key = "img" + soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=None, lower=5, sharpness_factor=1.0) + im = p(self.imt) + result = soft_clipper({key: im}) + lower = percentile(im, 5) + expected = soft_clip(im, sharpness_factor=1.0, minv=lower, maxv=None, dtype=torch.float32) + # the rtol is set to 1e-6 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-6, atol=0) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_channel_wise(self, p): + key = "img" + clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, channel_wise=True) + im = p(self.imt) + result = clipper({key: im}) + for i, c in enumerate(im): + lower, upper = percentile(c, (5, 95)) + expected = clip(c, lower, upper) + assert_allclose(result[key][i], p(expected), type_test="tensor", rtol=1e-4, atol=0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_complex_utils.py b/tests/test_complex_utils.py index 77eaa924a2..fdcee4babe 100644 --- a/tests/test_complex_utils.py +++ b/tests/test_complex_utils.py @@ -51,6 +51,7 @@ class TestMRIUtils(unittest.TestCase): + @parameterized.expand(TESTS) def test_to_tensor_complex(self, test_data, expected_shape): result = convert_to_tensor_complex(test_data) diff --git a/tests/test_component_locator.py b/tests/test_component_locator.py index 3b54a13706..9378fc159d 100644 --- a/tests/test_component_locator.py +++ b/tests/test_component_locator.py @@ -21,6 +21,7 @@ class TestComponentLocator(unittest.TestCase): + def test_locate(self): locator = ComponentLocator(excludes=None if has_ignite else ["monai.handlers"]) # test init mapping table and get the module path of component diff --git a/tests/test_component_store.py b/tests/test_component_store.py index 614f387754..424eceb3d1 100644 --- a/tests/test_component_store.py +++ b/tests/test_component_store.py @@ -17,6 +17,7 @@ class TestComponentStore(unittest.TestCase): + def setUp(self): self.cs = ComponentStore("TestStore", "I am a test store, please ignore") diff --git a/tests/test_compose.py b/tests/test_compose.py index a1952b102f..3c53ac4a22 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -39,6 +39,7 @@ def data_from_keys(keys, h, w): class _RandXform(Randomizable): + def randomize(self): self.val = self.R.random_sample() @@ -48,12 +49,14 @@ def __call__(self, __unused): class TestCompose(unittest.TestCase): + def test_empty_compose(self): c = mt.Compose() i = 1 self.assertEqual(c(i), 1) def test_non_dict_compose(self): + def a(i): return i + "a" @@ -64,6 +67,7 @@ def b(i): self.assertEqual(c(""), "abab") def test_dict_compose(self): + def a(d): d = dict(d) d["a"] += 1 @@ -82,6 +86,7 @@ def b(d): self.assertDictEqual(execute_compose(data, transforms), expected) def test_list_dict_compose(self): + def a(d): # transform to handle dict data d = dict(d) d["a"] += 1 @@ -109,6 +114,7 @@ def c(d): # transform to handle dict data self.assertDictEqual(item, expected) def test_non_dict_compose_with_unpack(self): + def a(i, i2): return i + "a", i2 + "a2" @@ -122,6 +128,7 @@ def b(i, i2): self.assertEqual(execute_compose(data, transforms, map_items=False, unpack_items=True), expected) def test_list_non_dict_compose_with_unpack(self): + def a(i, i2): return i + "a", i2 + "a2" @@ -135,6 +142,7 @@ def b(i, i2): self.assertEqual(execute_compose(data, transforms, unpack_items=True), expected) def test_list_dict_compose_no_map(self): + def a(d): # transform to handle dict data d = dict(d) d["a"] += 1 @@ -163,6 +171,7 @@ def c(d): # transform to handle dict data self.assertDictEqual(item, expected) def test_random_compose(self): + class _Acc(Randomizable): self.rand = 0.0 @@ -182,7 +191,9 @@ def __call__(self, data): self.assertAlmostEqual(c(1), 1.90734751) def test_randomize_warn(self): + class _RandomClass(Randomizable): + def randomize(self, foo1, foo2): pass @@ -267,6 +278,7 @@ def test_backwards_compatible_imports(self): class TestComposeExecute(unittest.TestCase): + @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES) def test_compose_execute_equivalence(self, keys, pipeline): data = data_from_keys(keys, 12, 16) @@ -657,8 +669,10 @@ def test_compose_lazy_on_call_with_logging(self, compose_type, pipeline, lazy_on class TestOps: + @staticmethod def concat(value): + def _inner(data): return data + value @@ -666,6 +680,7 @@ def _inner(data): @staticmethod def concatd(value): + def _inner(data): return {k: v + value for k, v in data.items()} @@ -673,6 +688,7 @@ def _inner(data): @staticmethod def concata(value): + def _inner(data1, data2): return data1 + value, data2 + value @@ -688,6 +704,7 @@ def _inner(data1, data2): class TestComposeExecuteWithFlags(unittest.TestCase): + @parameterized.expand(TEST_COMPOSE_EXECUTE_FLAG_TEST_CASES) def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline): expected = mt.Compose(pipeline, **flags)(data) @@ -699,18 +716,19 @@ def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline): for k in actual.keys(): self.assertEqual(expected[k], actual[k]) else: - self.assertTrue(expected, actual) + self.assertEqual(expected, actual) p = deepcopy(pipeline) actual = execute_compose(execute_compose(data, p, start=0, end=cutoff, **flags), p, start=cutoff, **flags) if isinstance(actual, dict): for k in actual.keys(): - self.assertTrue(expected[k], actual[k]) + self.assertEqual(expected[k], actual[k]) else: - self.assertTrue(expected, actual) + self.assertEqual(expected, actual) class TestComposeCallableInput(unittest.TestCase): + def test_value_error_when_not_sequence(self): data = torch.tensor(np.random.randn(1, 5, 5)) diff --git a/tests/test_compose_get_number_conversions.py b/tests/test_compose_get_number_conversions.py index 664558d9cd..2623bab69c 100644 --- a/tests/test_compose_get_number_conversions.py +++ b/tests/test_compose_get_number_conversions.py @@ -38,6 +38,7 @@ def _apply(x, fn): class Load(Transform): + def __init__(self, as_tensor): self.fn = lambda _: PT_ARR if as_tensor else NP_ARR @@ -46,26 +47,31 @@ def __call__(self, x): class N(Transform): + def __call__(self, x): return _apply(x, convert_to_numpy) class T(Transform): + def __call__(self, x): return _apply(x, convert_to_tensor) class NT(Transform): + def __call__(self, x): return _apply(x, lambda x: x) class TCPU(Transform): + def __call__(self, x): return _apply(x, lambda x: convert_to_tensor(x).cpu()) class TGPU(Transform): + def __call__(self, x): return _apply(x, lambda x: convert_to_tensor(x).cuda()) @@ -103,6 +109,7 @@ def __call__(self, x): class TestComposeNumConversions(unittest.TestCase): + @parameterized.expand(TESTS) def test_get_number_of_conversions(self, transforms, is_dict, input, expected): input = input if not is_dict else {KEY: input, "Other": NP_ARR} diff --git a/tests/test_compute_confusion_matrix.py b/tests/test_compute_confusion_matrix.py index e0a92aec67..248f16a7fe 100644 --- a/tests/test_compute_confusion_matrix.py +++ b/tests/test_compute_confusion_matrix.py @@ -220,6 +220,7 @@ class TestConfusionMatrix(unittest.TestCase): + @parameterized.expand([TEST_CASE_CONFUSION_MATRIX]) def test_value(self, input_data, expected_value): # include or ignore background diff --git a/tests/test_compute_f_beta.py b/tests/test_compute_f_beta.py index c8ed5aa887..43ebb6a6d5 100644 --- a/tests/test_compute_f_beta.py +++ b/tests/test_compute_f_beta.py @@ -15,6 +15,7 @@ import numpy as np import torch +from parameterized import parameterized from monai.metrics import FBetaScore from tests.utils import assert_allclose @@ -23,6 +24,7 @@ class TestFBetaScore(unittest.TestCase): + def test_expecting_success_and_device(self): metric = FBetaScore() y_pred = torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]], device=_device) @@ -32,26 +34,21 @@ def test_expecting_success_and_device(self): assert_allclose(result, torch.Tensor([0.714286]), atol=1e-6, rtol=1e-6) np.testing.assert_equal(result.device, y_pred.device) - def test_expecting_success2(self): - metric = FBetaScore(beta=0.5) - metric( - y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]) - ) - assert_allclose(metric.aggregate()[0], torch.Tensor([0.609756]), atol=1e-6, rtol=1e-6) - - def test_expecting_success3(self): - metric = FBetaScore(beta=2) - metric( - y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]) - ) - assert_allclose(metric.aggregate()[0], torch.Tensor([0.862069]), atol=1e-6, rtol=1e-6) - - def test_denominator_is_zero(self): - metric = FBetaScore(beta=2) - metric( - y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) - ) - assert_allclose(metric.aggregate()[0], torch.Tensor([0.0]), atol=1e-6, rtol=1e-6) + @parameterized.expand( + [ + (0.5, torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]), torch.Tensor([0.609756])), # success_beta_0_5 + (2, torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]), torch.Tensor([0.862069])), # success_beta_2 + ( + 2, # success_beta_2, denominator_zero + torch.Tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]]), + torch.Tensor([0.0]), + ), + ] + ) + def test_success_and_zero(self, beta, y, expected_score): + metric = FBetaScore(beta=beta) + metric(y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=y) + assert_allclose(metric.aggregate()[0], expected_score, atol=1e-6, rtol=1e-6) def test_number_of_dimensions_less_than_2_should_raise_error(self): metric = FBetaScore() diff --git a/tests/test_compute_fid_metric.py b/tests/test_compute_fid_metric.py index 1c7c3273fe..bd867f5296 100644 --- a/tests/test_compute_fid_metric.py +++ b/tests/test_compute_fid_metric.py @@ -24,6 +24,7 @@ @unittest.skipUnless(has_scipy, "Requires scipy") class TestFIDMetric(unittest.TestCase): + def test_results(self): x = torch.Tensor([[1, 2], [1, 2], [1, 2]]) y = torch.Tensor([[2, 2], [1, 2], [1, 2]]) diff --git a/tests/test_compute_froc.py b/tests/test_compute_froc.py index 0a48dc099a..4dc0507366 100644 --- a/tests/test_compute_froc.py +++ b/tests/test_compute_froc.py @@ -111,6 +111,7 @@ class TestComputeFpTp(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_value(self, input_data, expected_fp, expected_tp, expected_num): fp_probs, tp_probs, num_tumors = compute_fp_tp_probs(**input_data) @@ -120,6 +121,7 @@ def test_value(self, input_data, expected_fp, expected_tp, expected_num): class TestComputeFpTpNd(unittest.TestCase): + @parameterized.expand([TEST_CASE_ND_1, TEST_CASE_ND_2]) def test_value(self, input_data, expected_fp, expected_tp, expected_num): fp_probs, tp_probs, num_tumors = compute_fp_tp_probs_nd(**input_data) @@ -129,6 +131,7 @@ def test_value(self, input_data, expected_fp, expected_tp, expected_num): class TestComputeFrocScore(unittest.TestCase): + @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) def test_value(self, input_data, thresholds, expected_score): fps_per_image, total_sensitivity = compute_froc_curve_data(**input_data) diff --git a/tests/test_compute_generalized_dice.py b/tests/test_compute_generalized_dice.py index ab3d012c97..e04444e988 100644 --- a/tests/test_compute_generalized_dice.py +++ b/tests/test_compute_generalized_dice.py @@ -119,6 +119,7 @@ class TestComputeGeneralizedDiceScore(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) def test_device(self, input_data, _expected_value): result = compute_generalized_dice(**input_data) diff --git a/tests/test_compute_ho_ver_maps.py b/tests/test_compute_ho_ver_maps.py index 50598cb57b..bbd5230f04 100644 --- a/tests/test_compute_ho_ver_maps.py +++ b/tests/test_compute_ho_ver_maps.py @@ -62,6 +62,7 @@ @unittest.skipUnless(has_skimage, "Requires scikit-image library.") class ComputeHoVerMapsTests(unittest.TestCase): + @parameterized.expand(TESTS) def test_horizontal_certical_maps(self, in_type, arguments, mask, hv_mask): input_image = in_type(mask) diff --git a/tests/test_compute_ho_ver_maps_d.py b/tests/test_compute_ho_ver_maps_d.py index 27bb57988c..7b5ac0d9d7 100644 --- a/tests/test_compute_ho_ver_maps_d.py +++ b/tests/test_compute_ho_ver_maps_d.py @@ -63,6 +63,7 @@ @unittest.skipUnless(has_skimage, "Requires scikit-image library.") class ComputeHoVerMapsDictTests(unittest.TestCase): + @parameterized.expand(TESTS) def test_horizontal_certical_maps(self, in_type, arguments, mask, hv_mask): hv_key = list(hv_mask.keys())[0] diff --git a/tests/test_compute_meandice.py b/tests/test_compute_meandice.py index 46e1d67b1b..aae15483b5 100644 --- a/tests/test_compute_meandice.py +++ b/tests/test_compute_meandice.py @@ -252,6 +252,7 @@ class TestComputeMeanDice(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12]) def test_value(self, input_data, expected_value): result = compute_dice(**input_data) diff --git a/tests/test_compute_meaniou.py b/tests/test_compute_meaniou.py index d39edaa6f3..0b7a2bbce2 100644 --- a/tests/test_compute_meaniou.py +++ b/tests/test_compute_meaniou.py @@ -187,6 +187,7 @@ class TestComputeMeanIoU(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12]) def test_value(self, input_data, expected_value): result = compute_iou(**input_data) diff --git a/tests/test_compute_mmd_metric.py b/tests/test_compute_mmd_metric.py index d1b69b3dfe..96b5cbc089 100644 --- a/tests/test_compute_mmd_metric.py +++ b/tests/test_compute_mmd_metric.py @@ -36,6 +36,7 @@ class TestMMDMetric(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_results(self, input_param, input_data, expected_val): metric = MMDMetric(**input_param) diff --git a/tests/test_compute_multiscalessim_metric.py b/tests/test_compute_multiscalessim_metric.py index 4ebc5b7935..3df8026c2b 100644 --- a/tests/test_compute_multiscalessim_metric.py +++ b/tests/test_compute_multiscalessim_metric.py @@ -20,6 +20,7 @@ class TestMultiScaleSSIMMetric(unittest.TestCase): + def test2d_gaussian(self): set_determinism(0) preds = torch.abs(torch.randn(1, 1, 64, 64)) diff --git a/tests/test_compute_panoptic_quality.py b/tests/test_compute_panoptic_quality.py index a5858e91d1..a916ea32b2 100644 --- a/tests/test_compute_panoptic_quality.py +++ b/tests/test_compute_panoptic_quality.py @@ -92,6 +92,7 @@ @SkipIfNoModule("scipy.optimize") class TestPanopticQualityMetric(unittest.TestCase): + @parameterized.expand([TEST_FUNC_CASE_1, TEST_FUNC_CASE_2, TEST_FUNC_CASE_3, TEST_FUNC_CASE_4]) def test_value(self, input_params, expected_value): result = compute_panoptic_quality(**input_params) diff --git a/tests/test_compute_regression_metrics.py b/tests/test_compute_regression_metrics.py index b0fde3afe9..a8b7f03e47 100644 --- a/tests/test_compute_regression_metrics.py +++ b/tests/test_compute_regression_metrics.py @@ -45,6 +45,7 @@ def psnrmetric_np(max_val, y_pred, y): class TestRegressionMetrics(unittest.TestCase): + def test_shape_reduction(self): set_determinism(seed=123) device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/tests/test_compute_roc_auc.py b/tests/test_compute_roc_auc.py index 2f080c76cb..f2cb816db4 100644 --- a/tests/test_compute_roc_auc.py +++ b/tests/test_compute_roc_auc.py @@ -100,6 +100,7 @@ class TestComputeROCAUC(unittest.TestCase): + @parameterized.expand( [ TEST_CASE_1, diff --git a/tests/test_compute_variance.py b/tests/test_compute_variance.py index 8eaac10a6c..486a1e9f6f 100644 --- a/tests/test_compute_variance.py +++ b/tests/test_compute_variance.py @@ -109,6 +109,7 @@ class TestComputeVariance(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_value(self, input_data, expected_value): result = compute_variance(**input_data) diff --git a/tests/test_concat_itemsd.py b/tests/test_concat_itemsd.py index 322a95d7df..64c5d6e255 100644 --- a/tests/test_concat_itemsd.py +++ b/tests/test_concat_itemsd.py @@ -22,6 +22,7 @@ class TestConcatItemsd(unittest.TestCase): + def test_tensor_values(self): device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0") input_data = { diff --git a/tests/test_config_item.py b/tests/test_config_item.py index cb1e7ad552..4909ecf6be 100644 --- a/tests/test_config_item.py +++ b/tests/test_config_item.py @@ -37,7 +37,7 @@ TEST_CASE_5 = [{"_target_": "LoadImaged", "_disabled_": "true", "keys": ["image"]}, dict] # test non-monai modules and excludes TEST_CASE_6 = [{"_target_": "torch.optim.Adam", "params": torch.nn.PReLU().parameters(), "lr": 1e-4}, torch.optim.Adam] -TEST_CASE_7 = [{"_target_": "decollate_batch", "detach": True, "pad": True, "_mode_": "partial"}, partial] +TEST_CASE_7 = [{"_target_": "decollate_batch", "detach": True, "pad": True, "_mode_": "callable"}, partial] # test args contains "name" field TEST_CASE_8 = [ {"_target_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25}, @@ -52,6 +52,7 @@ class TestConfigItem(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) def test_item(self, test_input, expected): item = ConfigItem(config=test_input) diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 63254e7336..cc890a0522 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -72,6 +72,7 @@ def case_pdb_inst(sarg=None): class TestClass: + @staticmethod def compute(a, b, func=lambda x, y: x + y): return func(a, b) @@ -126,6 +127,7 @@ def __call__(self, a, b): class TestConfigParser(unittest.TestCase): + def test_config_content(self): test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}} parser = ConfigParser(config=test_config) @@ -181,7 +183,7 @@ def test_function(self, config): parser = ConfigParser(config=config, globals={"TestClass": TestClass}) for id in config: if id in ("compute", "cls_compute"): - parser[f"{id}#_mode_"] = "partial" + parser[f"{id}#_mode_"] = "callable" func = parser.get_parsed_content(id=id) self.assertTrue(id in parser.ref_resolver.resolved_content) if id == "error_func": @@ -277,7 +279,7 @@ def test_lambda_reference(self): def test_non_str_target(self): configs = { - "fwd": {"_target_": "$@model.forward", "x": "$torch.rand(1, 3, 256, 256)", "_mode_": "partial"}, + "fwd": {"_target_": "$@model.forward", "x": "$torch.rand(1, 3, 256, 256)", "_mode_": "callable"}, "model": {"_target_": "monai.networks.nets.resnet.resnet18", "pretrained": False, "spatial_dims": 2}, } self.assertTrue(callable(ConfigParser(config=configs).fwd)) diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py new file mode 100644 index 0000000000..64efe3b168 --- /dev/null +++ b/tests/test_conjugate_gradient.py @@ -0,0 +1,56 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.networks.layers import ConjugateGradient + + +class TestConjugateGradient(unittest.TestCase): + + def test_real_valued_inverse(self): + """Test ConjugateGradient with real-valued input: when the input is real + value, the output should be the inverse of the matrix.""" + a_dim = 3 + a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.float) + + def a_op(x): + return a_mat @ x + + cg_solver = ConjugateGradient(a_op, num_iter=100) + # define the measurement + y = torch.tensor([1, 2, 3], dtype=torch.float) + # solve for x + x = cg_solver(torch.zeros(a_dim), y) + x_ref = torch.linalg.solve(a_mat, y) + # assert torch.allclose(x, x_ref, atol=1e-6), 'CG solver failed to converge to reference solution' + self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) + + def test_complex_valued_inverse(self): + a_dim = 3 + a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.complex64) + + def a_op(x): + return a_mat @ x + + cg_solver = ConjugateGradient(a_op, num_iter=100) + y = torch.tensor([1, 2, 3], dtype=torch.complex64) + x = cg_solver(torch.zeros(a_dim, dtype=torch.complex64), y) + x_ref = torch.linalg.solve(a_mat, y) + self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_contrastive_loss.py b/tests/test_contrastive_loss.py index 4cafa0a905..21a9e76417 100644 --- a/tests/test_contrastive_loss.py +++ b/tests/test_contrastive_loss.py @@ -55,6 +55,7 @@ class TestContrastiveLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_result(self, input_param, input_data, expected_val): contrastiveloss = ContrastiveLoss(**input_param) diff --git a/tests/test_convert_data_type.py b/tests/test_convert_data_type.py index c3e4490ffe..b95539f4b7 100644 --- a/tests/test_convert_data_type.py +++ b/tests/test_convert_data_type.py @@ -77,6 +77,7 @@ class TestTensor(torch.Tensor): class TestConvertDataType(unittest.TestCase): + @parameterized.expand(TESTS) def test_convert_data_type(self, in_image, im_out, out_dtype, safe): converted_im, orig_type, orig_device = convert_data_type(in_image, type(im_out), dtype=out_dtype, safe=safe) diff --git a/tests/test_convert_to_multi_channel.py b/tests/test_convert_to_multi_channel.py index 78c3c90688..98bbea1ebf 100644 --- a/tests/test_convert_to_multi_channel.py +++ b/tests/test_convert_to_multi_channel.py @@ -48,6 +48,7 @@ class TestConvertToMultiChannel(unittest.TestCase): + @parameterized.expand(TESTS) def test_type_shape(self, data, expected_result): result = ConvertToMultiChannelBasedOnBratsClasses()(data) diff --git a/tests/test_convert_to_multi_channeld.py b/tests/test_convert_to_multi_channeld.py index 351adddb13..e482770497 100644 --- a/tests/test_convert_to_multi_channeld.py +++ b/tests/test_convert_to_multi_channeld.py @@ -26,6 +26,7 @@ class TestConvertToMultiChanneld(unittest.TestCase): + @parameterized.expand([TEST_CASE]) def test_type_shape(self, keys, data, expected_result): result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data) diff --git a/tests/test_convert_to_onnx.py b/tests/test_convert_to_onnx.py index 7560c98703..798c510800 100644 --- a/tests/test_convert_to_onnx.py +++ b/tests/test_convert_to_onnx.py @@ -12,6 +12,7 @@ from __future__ import annotations import itertools +import platform import unittest import torch @@ -29,6 +30,12 @@ TESTS = list(itertools.product(TORCH_DEVICE_OPTIONS, [True, False], [True, False])) TESTS_ORT = list(itertools.product(TORCH_DEVICE_OPTIONS, [True])) +ON_AARCH64 = platform.machine() == "aarch64" +if ON_AARCH64: + rtol, atol = 1e-1, 1e-2 +else: + rtol, atol = 1e-3, 1e-4 + onnx, _ = optional_import("onnx") @@ -36,6 +43,7 @@ @SkipIfBeforePyTorchVersion((1, 9)) @skip_if_quick class TestConvertToOnnx(unittest.TestCase): + @parameterized.expand(TESTS) def test_unet(self, device, use_trace, use_ort): if use_ort: @@ -55,8 +63,8 @@ def test_unet(self, device, use_trace, use_ort): device=device, use_ort=use_ort, use_trace=use_trace, - rtol=1e-3, - atol=1e-4, + rtol=rtol, + atol=atol, ) else: # https://github.com/pytorch/pytorch/blob/release/1.9/torch/onnx/__init__.py#L182 @@ -71,8 +79,8 @@ def test_unet(self, device, use_trace, use_ort): device=device, use_ort=use_ort, use_trace=use_trace, - rtol=1e-3, - atol=1e-4, + rtol=rtol, + atol=atol, ) self.assertTrue(isinstance(onnx_model, onnx.ModelProto)) @@ -106,8 +114,8 @@ def test_seg_res_net(self, device, use_ort): device=device, use_ort=use_ort, use_trace=True, - rtol=1e-3, - atol=1e-4, + rtol=rtol, + atol=atol, ) self.assertTrue(isinstance(onnx_model, onnx.ModelProto)) diff --git a/tests/test_convert_to_torchscript.py b/tests/test_convert_to_torchscript.py index 0b8e9a8141..c78b8e78c0 100644 --- a/tests/test_convert_to_torchscript.py +++ b/tests/test_convert_to_torchscript.py @@ -22,6 +22,7 @@ class TestConvertToTorchScript(unittest.TestCase): + def test_value(self): model = UNet( spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0 diff --git a/tests/test_convert_to_trt.py b/tests/test_convert_to_trt.py index 108ed66f31..5579539764 100644 --- a/tests/test_convert_to_trt.py +++ b/tests/test_convert_to_trt.py @@ -39,6 +39,7 @@ @skip_if_no_cuda @skip_if_quick class TestConvertToTRT(unittest.TestCase): + def setUp(self): self.gpu_device = torch.cuda.current_device() diff --git a/tests/test_convolutions.py b/tests/test_convolutions.py index 1311401f1d..77bc12770f 100644 --- a/tests/test_convolutions.py +++ b/tests/test_convolutions.py @@ -18,6 +18,7 @@ class TestConvolution2D(TorchImageTestCase2D): + def test_conv1(self): conv = Convolution(2, self.input_channels, self.output_channels) out = conv(self.imt) @@ -69,6 +70,7 @@ def test_transpose2(self): class TestConvolution3D(TorchImageTestCase3D): + def test_conv1(self): conv = Convolution(3, self.input_channels, self.output_channels, dropout=0.1, adn_ordering="DAN") out = conv(self.imt) @@ -126,6 +128,7 @@ def test_transpose2(self): class TestResidualUnit2D(TorchImageTestCase2D): + def test_conv_only1(self): conv = ResidualUnit(2, 1, self.output_channels) out = conv(self.imt) diff --git a/tests/test_copy_itemsd.py b/tests/test_copy_itemsd.py index ff4799a094..a78e08897b 100644 --- a/tests/test_copy_itemsd.py +++ b/tests/test_copy_itemsd.py @@ -32,6 +32,7 @@ class TestCopyItemsd(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_numpy_values(self, keys, times, names): input_data = {"img": np.array([[0, 1], [1, 2]]), "seg": np.array([[3, 4], [4, 5]])} diff --git a/tests/test_copy_model_state.py b/tests/test_copy_model_state.py index 2e7513b234..26b01d930a 100644 --- a/tests/test_copy_model_state.py +++ b/tests/test_copy_model_state.py @@ -22,6 +22,7 @@ class _TestModelOne(torch.nn.Module): + def __init__(self, n_n, n_m, n_class): super().__init__() self.layer = torch.nn.Linear(n_n, n_m) @@ -34,6 +35,7 @@ def forward(self, x): class _TestModelTwo(torch.nn.Module): + def __init__(self, n_n, n_m, n_d, n_class): super().__init__() self.layer = torch.nn.Linear(n_n, n_m) @@ -55,6 +57,7 @@ def forward(self, x): class TestModuleState(unittest.TestCase): + def tearDown(self): set_determinism(None) diff --git a/tests/test_correct_crop_centers.py b/tests/test_correct_crop_centers.py index d2a95bf684..82b0b93b53 100644 --- a/tests/test_correct_crop_centers.py +++ b/tests/test_correct_crop_centers.py @@ -23,6 +23,7 @@ class TestCorrectCropCenters(unittest.TestCase): + @parameterized.expand(TESTS) def test_torch(self, spatial_size, centers, label_spatial_shape): result1 = correct_crop_centers(centers, spatial_size, label_spatial_shape) diff --git a/tests/test_create_cross_validation_datalist.py b/tests/test_create_cross_validation_datalist.py index d05a94f59e..0e80be1cd0 100644 --- a/tests/test_create_cross_validation_datalist.py +++ b/tests/test_create_cross_validation_datalist.py @@ -20,6 +20,7 @@ class TestCreateCrossValidationDatalist(unittest.TestCase): + def test_content(self): with tempfile.TemporaryDirectory() as tempdir: datalist = [] diff --git a/tests/test_create_grid_and_affine.py b/tests/test_create_grid_and_affine.py index 2b5890a777..4910a10470 100644 --- a/tests/test_create_grid_and_affine.py +++ b/tests/test_create_grid_and_affine.py @@ -28,6 +28,7 @@ class TestCreateGrid(unittest.TestCase): + def test_create_grid(self): with self.assertRaisesRegex(TypeError, ""): create_grid(None) @@ -168,6 +169,7 @@ def test_assert(func, params, expected): class TestCreateAffine(unittest.TestCase): + def test_create_rotate(self): with self.assertRaisesRegex(TypeError, ""): create_rotate(2, None) diff --git a/tests/test_crf_cpu.py b/tests/test_crf_cpu.py index e29a4d69eb..a7ae0ff2df 100644 --- a/tests/test_crf_cpu.py +++ b/tests/test_crf_cpu.py @@ -495,6 +495,7 @@ @skip_if_no_cpp_extension class CRFTestCaseCpu(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test(self, test_case_description, params, input, features, expected): # Create input tensors diff --git a/tests/test_crf_cuda.py b/tests/test_crf_cuda.py index 8529e2e6de..d5329aab15 100644 --- a/tests/test_crf_cuda.py +++ b/tests/test_crf_cuda.py @@ -496,6 +496,7 @@ @skip_if_no_cpp_extension @skip_if_no_cuda class CRFTestCaseCuda(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test(self, test_case_description, params, input, features, expected): # Create input tensors diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index 4435b128ba..f63cb3e8b0 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -99,6 +99,7 @@ class TestCropForeground(unittest.TestCase): + @parameterized.expand(TEST_COORDS + TESTS) def test_value(self, arguments, image, expected_data, _): cropper = CropForeground(**arguments) diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py index 776776f6c5..92954aa81e 100644 --- a/tests/test_crop_foregroundd.py +++ b/tests/test_crop_foregroundd.py @@ -158,6 +158,7 @@ class TestCropForegroundd(unittest.TestCase): + @parameterized.expand(TEST_POSITION + TESTS) def test_value(self, arguments, input_data, expected_data, _): cropper = CropForegroundd(**arguments) diff --git a/tests/test_cross_validation.py b/tests/test_cross_validation.py index de1122eeae..6d0f2319fb 100644 --- a/tests/test_cross_validation.py +++ b/tests/test_cross_validation.py @@ -21,6 +21,7 @@ class TestCrossValidation(unittest.TestCase): + @skip_if_quick def test_values(self): testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") diff --git a/tests/test_csv_dataset.py b/tests/test_csv_dataset.py index 82a0f7afbd..71be4fdd22 100644 --- a/tests/test_csv_dataset.py +++ b/tests/test_csv_dataset.py @@ -23,6 +23,7 @@ class TestCSVDataset(unittest.TestCase): + def test_values(self): with tempfile.TemporaryDirectory() as tempdir: test_data1 = [ diff --git a/tests/test_csv_iterable_dataset.py b/tests/test_csv_iterable_dataset.py index 65a0a420a5..e06da0c41b 100644 --- a/tests/test_csv_iterable_dataset.py +++ b/tests/test_csv_iterable_dataset.py @@ -26,6 +26,7 @@ @skip_if_windows class TestCSVIterableDataset(unittest.TestCase): + def test_values(self): with tempfile.TemporaryDirectory() as tempdir: test_data1 = [ diff --git a/tests/test_csv_saver.py b/tests/test_csv_saver.py index 833d1134cf..234b3f1057 100644 --- a/tests/test_csv_saver.py +++ b/tests/test_csv_saver.py @@ -23,6 +23,7 @@ class TestCSVSaver(unittest.TestCase): + def test_saved_content(self): with tempfile.TemporaryDirectory() as tempdir: saver = CSVSaver(output_dir=tempdir, filename="predictions.csv", delimiter="\t") diff --git a/tests/test_cucim_dict_transform.py b/tests/test_cucim_dict_transform.py index 6ebfd8bac7..d2dcc6aa5f 100644 --- a/tests/test_cucim_dict_transform.py +++ b/tests/test_cucim_dict_transform.py @@ -66,6 +66,7 @@ @unittest.skipUnless(HAS_CUPY, "CuPy is required.") @unittest.skipUnless(has_cut, "cuCIM transforms are required.") class TestCuCIMDict(unittest.TestCase): + @parameterized.expand( [ TEST_CASE_COLOR_JITTER_1, diff --git a/tests/test_cucim_transform.py b/tests/test_cucim_transform.py index 5884358a74..5f16c11589 100644 --- a/tests/test_cucim_transform.py +++ b/tests/test_cucim_transform.py @@ -66,6 +66,7 @@ @unittest.skipUnless(HAS_CUPY, "CuPy is required.") @unittest.skipUnless(has_cut, "cuCIM transforms are required.") class TestCuCIM(unittest.TestCase): + @parameterized.expand( [ TEST_CASE_COLOR_JITTER_1, diff --git a/tests/test_cumulative.py b/tests/test_cumulative.py index 3377fa815c..d3b6ba094c 100644 --- a/tests/test_cumulative.py +++ b/tests/test_cumulative.py @@ -20,6 +20,7 @@ class TestCumulative(unittest.TestCase): + def test_single(self): c = Cumulative() c.extend([2, 3]) diff --git a/tests/test_cumulative_average.py b/tests/test_cumulative_average.py index d815d9be77..624da2c7bb 100644 --- a/tests/test_cumulative_average.py +++ b/tests/test_cumulative_average.py @@ -32,6 +32,7 @@ class TestAverageMeter(unittest.TestCase): + @parameterized.expand(TEST_CASE_1) def test_value_all(self, data): # test orig diff --git a/tests/test_cumulative_average_dist.py b/tests/test_cumulative_average_dist.py index 17f4164838..30c01c21ee 100644 --- a/tests/test_cumulative_average_dist.py +++ b/tests/test_cumulative_average_dist.py @@ -23,6 +23,7 @@ @SkipIfBeforePyTorchVersion((1, 8)) class DistributedCumulativeAverage(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) def test_value(self): rank = dist.get_rank() diff --git a/tests/test_cv2_dist.py b/tests/test_cv2_dist.py index edd2e1ec42..562c205763 100644 --- a/tests/test_cv2_dist.py +++ b/tests/test_cv2_dist.py @@ -42,6 +42,7 @@ def main_worker(rank, ngpus_per_node, port): @skip_if_no_cuda class TestCV2Dist(unittest.TestCase): + def test_cv2_cuda_ops(self): print_config() ngpus_per_node = torch.cuda.device_count() diff --git a/tests/test_daf3d.py b/tests/test_daf3d.py index 34e25cc6be..d20cb3cfd1 100644 --- a/tests/test_daf3d.py +++ b/tests/test_daf3d.py @@ -42,6 +42,7 @@ @unittest.skipUnless(has_tv, "torchvision not installed") class TestDAF3D(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/tests/test_data_stats.py b/tests/test_data_stats.py index 6ef51bef92..05453b0694 100644 --- a/tests/test_data_stats.py +++ b/tests/test_data_stats.py @@ -137,6 +137,7 @@ class TestDataStats(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) def test_value(self, input_param, input_data, expected_print): transform = DataStats(**input_param) diff --git a/tests/test_data_statsd.py b/tests/test_data_statsd.py index 374bc815ac..ef88300c10 100644 --- a/tests/test_data_statsd.py +++ b/tests/test_data_statsd.py @@ -157,6 +157,7 @@ class TestDataStatsd(unittest.TestCase): + @parameterized.expand( [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] ) diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 2ee69687a6..73e27799f7 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -29,6 +29,7 @@ class TestDataLoader(unittest.TestCase): + def test_values(self): datalist = [ {"image": "spleen_19.nii.gz", "label": "spleen_label_19.nii.gz"}, @@ -59,6 +60,7 @@ def test_exception(self, datalist): class _RandomDataset(torch.utils.data.Dataset, Randomizable): + def __getitem__(self, index): return self.R.randint(0, 1000, (1,)) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index c7c2b77697..1398009c63 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -30,6 +30,7 @@ class TestDataset(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) def test_shape(self, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) diff --git a/tests/test_dataset_func.py b/tests/test_dataset_func.py index afccd129fe..166d888d9e 100644 --- a/tests/test_dataset_func.py +++ b/tests/test_dataset_func.py @@ -20,6 +20,7 @@ class TestDatasetFunc(unittest.TestCase): + def test_seg_values(self): with tempfile.TemporaryDirectory() as tempdir: # prepare test datalist file diff --git a/tests/test_dataset_summary.py b/tests/test_dataset_summary.py index 87538425d5..21cc53de90 100644 --- a/tests/test_dataset_summary.py +++ b/tests/test_dataset_summary.py @@ -36,6 +36,7 @@ def test_collate(batch): class TestDatasetSummary(unittest.TestCase): + def test_spacing_intensity(self): set_determinism(seed=0) with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_decathlondataset.py b/tests/test_decathlondataset.py index 345cc487c5..70a2a6c06c 100644 --- a/tests/test_decathlondataset.py +++ b/tests/test_decathlondataset.py @@ -23,6 +23,7 @@ class TestDecathlonDataset(unittest.TestCase): + @skip_if_quick def test_values(self): testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") @@ -79,7 +80,7 @@ def _test_dataset(dataset): self.assertDictEqual(properties["labels"], {"0": "background", "1": "Anterior", "2": "Posterior"}) shutil.rmtree(os.path.join(testing_dir, "Task04_Hippocampus")) - try: + with self.assertRaisesRegex(RuntimeError, "^Cannot find dataset directory"): DecathlonDataset( root_dir=testing_dir, task="Task04_Hippocampus", @@ -87,9 +88,6 @@ def _test_dataset(dataset): section="validation", download=False, ) - except RuntimeError as e: - print(str(e)) - self.assertTrue(str(e).startswith("Cannot find dataset directory")) if __name__ == "__main__": diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 26b9e7a4f4..92f7c89e28 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -81,6 +81,7 @@ class TestDeCollate(unittest.TestCase): + def setUp(self) -> None: set_determinism(seed=0) @@ -159,6 +160,7 @@ def test_decollation_list(self, *transforms): class TestBasicDeCollate(unittest.TestCase): + @parameterized.expand(TEST_BASIC) def test_decollation_examples(self, input_val, expected_out): out = decollate_batch(input_val) diff --git a/tests/test_deepedit_interaction.py b/tests/test_deepedit_interaction.py index 5dcc6205f7..8baf4dc827 100644 --- a/tests/test_deepedit_interaction.py +++ b/tests/test_deepedit_interaction.py @@ -40,6 +40,7 @@ def add_one(engine): class TestInteractions(unittest.TestCase): + def run_interaction(self, train): label_names = {"spleen": 1, "background": 0} np.random.seed(0) diff --git a/tests/test_deepedit_transforms.py b/tests/test_deepedit_transforms.py index 7f4d4eee1e..18d6567fd7 100644 --- a/tests/test_deepedit_transforms.py +++ b/tests/test_deepedit_transforms.py @@ -209,6 +209,7 @@ class TestAddGuidanceFromPointsCustomd(unittest.TestCase): + @parameterized.expand([ADD_GUIDANCE_FROM_POINTS_TEST_CASE]) def test_correct_results(self, arguments, input_data, expected_result): add_fn = AddGuidanceFromPointsDeepEditd(**arguments) @@ -217,6 +218,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestAddGuidanceSignalCustomd(unittest.TestCase): + @parameterized.expand([ADD_GUIDANCE_CUSTOM_TEST_CASE]) def test_correct_results(self, arguments, input_data, expected_result): add_fn = AddGuidanceSignalDeepEditd(**arguments) @@ -225,6 +227,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestAddInitialSeedPointMissingLabelsd(unittest.TestCase): + @parameterized.expand([ADD_INITIAL_POINT_TEST_CASE]) def test_correct_results(self, arguments, input_data, expected_result): seed = 0 @@ -235,6 +238,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestAddRandomGuidanceCustomd(unittest.TestCase): + @parameterized.expand([ADD_RANDOM_GUIDANCE_TEST_CASE]) def test_correct_results(self, arguments, input_data, expected_result): add_fn = AddRandomGuidanceDeepEditd(**arguments) @@ -244,6 +248,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestDiscardAddGuidanced(unittest.TestCase): + @parameterized.expand([DISCARD_ADD_GUIDANCE_TEST_CASE]) def test_correct_results(self, arguments, input_data, expected_result): add_fn = DiscardAddGuidanced(**arguments) @@ -252,6 +257,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestFindAllValidSlicesMissingLabelsd(unittest.TestCase): + @parameterized.expand([FIND_SLICE_TEST_CASE]) def test_correct_results(self, arguments, input_data, expected_result): add_fn = FindAllValidSlicesMissingLabelsd(**arguments) @@ -260,6 +266,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestFindDiscrepancyRegionsCustomd(unittest.TestCase): + @parameterized.expand([FIND_DISCREPANCY_TEST_CASE]) def test_correct_results(self, arguments, input_data, expected_result): add_fn = FindDiscrepancyRegionsDeepEditd(**arguments) @@ -268,6 +275,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestNormalizeLabelsDatasetd(unittest.TestCase): + @parameterized.expand([NormalizeLabelsDatasetd_TEST_CASE]) def test_correct_results(self, arguments, input_data, expected_result): add_fn = NormalizeLabelsInDatasetd(**arguments) @@ -276,6 +284,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestResizeGuidanceMultipleLabelCustomd(unittest.TestCase): + @parameterized.expand([RESIZE_GUIDANCE_TEST_CASE]) def test_correct_results(self, arguments, input_data, expected_result): add_fn = ResizeGuidanceMultipleLabelDeepEditd(**arguments) @@ -284,6 +293,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestSingleLabelSelectiond(unittest.TestCase): + @parameterized.expand([SingleLabelSelectiond_TEST_CASE]) def test_correct_results(self, arguments, input_data, expected_result): add_fn = SingleLabelSelectiond(**arguments) @@ -292,6 +302,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestSplitPredsLabeld(unittest.TestCase): + @parameterized.expand([SplitPredsLabeld_TEST_CASE]) def test_correct_results(self, arguments, input_data, expected_result): add_fn = SplitPredsLabeld(**arguments) diff --git a/tests/test_deepgrow_dataset.py b/tests/test_deepgrow_dataset.py index d8a412ade9..b8d630960c 100644 --- a/tests/test_deepgrow_dataset.py +++ b/tests/test_deepgrow_dataset.py @@ -51,6 +51,7 @@ class TestCreateDataset(unittest.TestCase): + def setUp(self): set_determinism(1) self.tempdir = tempfile.mkdtemp() diff --git a/tests/test_deepgrow_interaction.py b/tests/test_deepgrow_interaction.py index 7cdbeed9f9..35759699f8 100644 --- a/tests/test_deepgrow_interaction.py +++ b/tests/test_deepgrow_interaction.py @@ -38,6 +38,7 @@ def add_one(engine): class TestInteractions(unittest.TestCase): + def run_interaction(self, train, compose): data = [{"image": np.ones((1, 2, 2, 2)).astype(np.float32), "label": np.ones((1, 2, 2, 2))} for _ in range(5)] network = torch.nn.Linear(2, 2) diff --git a/tests/test_deepgrow_transforms.py b/tests/test_deepgrow_transforms.py index 1328e13439..a491a8004b 100644 --- a/tests/test_deepgrow_transforms.py +++ b/tests/test_deepgrow_transforms.py @@ -337,6 +337,7 @@ class TestFindAllValidSlicesd(unittest.TestCase): + @parameterized.expand([FIND_SLICE_TEST_CASE_1, FIND_SLICE_TEST_CASE_2]) def test_correct_results(self, arguments, input_data, expected_result): result = FindAllValidSlicesd(**arguments)(input_data) @@ -344,6 +345,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestSpatialCropForegroundd(unittest.TestCase): + @parameterized.expand([CROP_TEST_CASE_1]) def test_correct_results(self, arguments, input_data, expected_result): result = SpatialCropForegroundd(**arguments)(input_data) @@ -368,6 +370,7 @@ def test_foreground_position(self, arguments, input_data, _): class TestAddInitialSeedPointd(unittest.TestCase): + @parameterized.expand([ADD_INITIAL_POINT_TEST_CASE_1]) def test_correct_results(self, arguments, input_data, expected_result): seed = 0 @@ -378,6 +381,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestAddGuidanceSignald(unittest.TestCase): + @parameterized.expand([ADD_GUIDANCE_TEST_CASE_1]) def test_correct_results(self, arguments, input_data, expected_result): result = AddGuidanceSignald(**arguments)(input_data) @@ -385,6 +389,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestFindDiscrepancyRegionsd(unittest.TestCase): + @parameterized.expand([FIND_DISCREPANCY_TEST_CASE_1]) def test_correct_results(self, arguments, input_data, expected_result): result = FindDiscrepancyRegionsd(**arguments)(input_data) @@ -392,6 +397,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestAddRandomGuidanced(unittest.TestCase): + @parameterized.expand([ADD_RANDOM_GUIDANCE_TEST_CASE_1]) def test_correct_results(self, arguments, input_data, expected_result): seed = 0 @@ -402,6 +408,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestAddGuidanceFromPointsd(unittest.TestCase): + @parameterized.expand( [ ADD_GUIDANCE_FROM_POINTS_TEST_CASE_1, @@ -419,6 +426,7 @@ def test_correct_results(self, arguments, input_data, expected_pos, expected_neg class TestSpatialCropGuidanced(unittest.TestCase): + @parameterized.expand( [SPATIAL_CROP_GUIDANCE_TEST_CASE_1, SPATIAL_CROP_GUIDANCE_TEST_CASE_2, SPATIAL_CROP_GUIDANCE_TEST_CASE_3] ) @@ -428,6 +436,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestResizeGuidanced(unittest.TestCase): + @parameterized.expand([RESIZE_GUIDANCE_TEST_CASE_1]) def test_correct_results(self, arguments, input_data, expected_result): result = ResizeGuidanced(**arguments)(input_data) @@ -435,6 +444,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestRestoreLabeld(unittest.TestCase): + @parameterized.expand([RESTORE_LABEL_TEST_CASE_1, RESTORE_LABEL_TEST_CASE_2]) def test_correct_results(self, arguments, input_data, expected_result): result = RestoreLabeld(**arguments)(input_data) @@ -442,6 +452,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestFetch2DSliced(unittest.TestCase): + @parameterized.expand([FETCH_2D_SLICE_TEST_CASE_1]) def test_correct_results(self, arguments, input_data, expected_result): result = Fetch2DSliced(**arguments)(input_data) diff --git a/tests/test_delete_itemsd.py b/tests/test_delete_itemsd.py index 1ec77f29fd..c57184cd9f 100644 --- a/tests/test_delete_itemsd.py +++ b/tests/test_delete_itemsd.py @@ -28,6 +28,7 @@ class TestDeleteItemsd(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_memory(self, input_param, expected_key_size): input_data = {"image": {}} if "sep" in input_param else {} diff --git a/tests/test_denseblock.py b/tests/test_denseblock.py index c14ca2ae7a..b741582422 100644 --- a/tests/test_denseblock.py +++ b/tests/test_denseblock.py @@ -20,6 +20,7 @@ class TestDenseBlock2D(TorchImageTestCase2D): + def test_block_empty(self): block = DenseBlock([]) out = block(self.imt) @@ -36,6 +37,7 @@ def test_block_conv(self): class TestDenseBlock3D(TorchImageTestCase3D): + def test_block_conv(self): conv1 = nn.Conv3d(self.input_channels, self.output_channels, 3, padding=1) conv2 = nn.Conv3d(self.input_channels + self.output_channels, self.input_channels, 3, padding=1) @@ -52,6 +54,7 @@ def test_block_conv(self): class TestConvDenseBlock2D(TorchImageTestCase2D): + def test_block_empty(self): conv = ConvDenseBlock(spatial_dims=2, in_channels=self.input_channels, channels=[]) out = conv(self.imt) @@ -79,6 +82,7 @@ def test_block2(self): class TestConvDenseBlock3D(TorchImageTestCase3D): + def test_block_empty(self): conv = ConvDenseBlock(spatial_dims=3, in_channels=self.input_channels, channels=[]) out = conv(self.imt) diff --git a/tests/test_densenet.py b/tests/test_densenet.py index 1b44baf0c2..ee4be9003b 100644 --- a/tests/test_densenet.py +++ b/tests/test_densenet.py @@ -79,6 +79,7 @@ class TestPretrainedDENSENET(unittest.TestCase): + @parameterized.expand([TEST_PRETRAINED_2D_CASE_1, TEST_PRETRAINED_2D_CASE_2]) @skip_if_quick def test_121_2d_shape_pretrain(self, model, input_param, input_shape, expected_shape): @@ -103,6 +104,7 @@ def test_pretrain_consistency(self, model, input_param, input_shape): class TestDENSENET(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_densenet_shape(self, model, input_param, input_shape, expected_shape): net = model(**input_param).to(device) diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 5d511f3821..3171a67e2c 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -18,6 +18,7 @@ class TestDeprecatedRC(unittest.TestCase): + def setUp(self): self.test_version_rc = "0.6.0rc1" self.test_version = "0.6.0" @@ -61,6 +62,7 @@ def foo3(): class TestDeprecated(unittest.TestCase): + def setUp(self): self.test_version = "0.5.3+96.g1fa03c2.dirty" self.prev_version = "0.4.3+96.g1fa03c2.dirty" @@ -142,6 +144,7 @@ def test_meth_warning1(self): """Test deprecated decorator with just `since` set.""" class Foo5: + @deprecated(since=self.prev_version, version_val=self.test_version) def meth1(self): pass @@ -152,6 +155,7 @@ def test_meth_except1(self): """Test deprecated decorator with just `since` set.""" class Foo6: + @deprecated(version_val=self.test_version) def meth1(self): pass @@ -389,6 +393,7 @@ def test_deprecated_arg_default_errors(self): # since > replaced def since_grater_than_replaced(): + @deprecated_arg_default( "b", old_default="a", @@ -404,6 +409,7 @@ def foo(a, b=None): # argname doesnt exist def argname_doesnt_exist(): + @deprecated_arg_default( "other", old_default="a", new_default="b", since=self.test_version, version_val=self.test_version ) @@ -414,6 +420,7 @@ def foo(a, b=None): # argname has no default def argname_has_no_default(): + @deprecated_arg_default( "a", old_default="a", @@ -429,6 +436,7 @@ def foo(a): # new default is used but version < replaced def argname_was_replaced_before_specified_version(): + @deprecated_arg_default( "a", old_default="a", diff --git a/tests/test_detect_envelope.py b/tests/test_detect_envelope.py index 105d3a4ace..e2efefeb77 100644 --- a/tests/test_detect_envelope.py +++ b/tests/test_detect_envelope.py @@ -116,6 +116,7 @@ @SkipIfNoModule("torch.fft") class TestDetectEnvelope(unittest.TestCase): + @parameterized.expand( [ TEST_CASE_1D_SINE, @@ -151,6 +152,7 @@ def test_value_error(self, arguments, image, method): @SkipIfModule("torch.fft") class TestHilbertTransformNoFFTMod(unittest.TestCase): + def test_no_fft_module_error(self): self.assertRaises(OptionalImportError, DetectEnvelope(), np.random.rand(1, 10)) diff --git a/tests/test_detection_coco_metrics.py b/tests/test_detection_coco_metrics.py index 780031ee0c..a85eb37db7 100644 --- a/tests/test_detection_coco_metrics.py +++ b/tests/test_detection_coco_metrics.py @@ -23,6 +23,7 @@ class TestCOCOMetrics(unittest.TestCase): + def test_coco_run(self): coco_metric = COCOMetric(classes=["c0", "c1", "c2"], iou_list=[0.1], max_detection=[10]) diff --git a/tests/test_detector_boxselector.py b/tests/test_detector_boxselector.py index 8cc9b15911..326ecd5773 100644 --- a/tests/test_detector_boxselector.py +++ b/tests/test_detector_boxselector.py @@ -56,6 +56,7 @@ class TestBoxSelector(unittest.TestCase): + @parameterized.expand(TEST_CASE) def test_box_selector(self, input_param, boxes, logits, image_shape, expected_results): box_selector = BoxSelector(**input_param) diff --git a/tests/test_detector_utils.py b/tests/test_detector_utils.py index 41716934b5..352e1c2faf 100644 --- a/tests/test_detector_utils.py +++ b/tests/test_detector_utils.py @@ -79,6 +79,7 @@ class TestDetectorUtils(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_detector_utils(self, input_param, input_shape, expected_shape): size_divisible = 32 * ensure_tuple(input_param["conv1_t_stride"])[0] diff --git a/tests/test_dev_collate.py b/tests/test_dev_collate.py index 97028f2597..44c4d2c598 100644 --- a/tests/test_dev_collate.py +++ b/tests/test_dev_collate.py @@ -36,6 +36,7 @@ class DevCollateTest(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_dev_collate(self, inputs, msg): with self.assertLogs(level=logging.CRITICAL) as log: diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py index 58b9f4c191..97c7ae5050 100644 --- a/tests/test_dice_ce_loss.py +++ b/tests/test_dice_ce_loss.py @@ -86,16 +86,27 @@ class TestDiceCELoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_result(self, input_param, input_data, expected_val): diceceloss = DiceCELoss(**input_param) result = diceceloss(**input_data) np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) - # def test_ill_shape(self): - # loss = DiceCELoss() - # with self.assertRaisesRegex(ValueError, ""): - # loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + def test_ill_shape(self): + loss = DiceCELoss() + with self.assertRaises(AssertionError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5))) + + def test_ill_shape2(self): + loss = DiceCELoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + + def test_ill_shape3(self): + loss = DiceCELoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4))) # def test_ill_reduction(self): # with self.assertRaisesRegex(ValueError, ""): diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py index 845ef40cd5..814a174762 100644 --- a/tests/test_dice_focal_loss.py +++ b/tests/test_dice_focal_loss.py @@ -22,6 +22,7 @@ class TestDiceFocalLoss(unittest.TestCase): + def test_result_onehot_target_include_bg(self): size = [3, 3, 5, 5] label = torch.randint(low=0, high=2, size=size) @@ -68,8 +69,18 @@ def test_result_no_onehot_no_bg(self, size, onehot): def test_ill_shape(self): loss = DiceFocalLoss() - with self.assertRaisesRegex(ValueError, ""): - loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + with self.assertRaises(AssertionError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5))) + + def test_ill_shape2(self): + loss = DiceFocalLoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + + def test_ill_shape3(self): + loss = DiceFocalLoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4))) def test_ill_lambda(self): with self.assertRaisesRegex(ValueError, ""): diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index 370d2dd5af..14aa6ec241 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -168,6 +168,7 @@ class TestDiceLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): result = DiceLoss(**input_param).forward(**input_data) diff --git a/tests/test_diffusion_loss.py b/tests/test_diffusion_loss.py index 05dfab95fb..93df77cc51 100644 --- a/tests/test_diffusion_loss.py +++ b/tests/test_diffusion_loss.py @@ -79,6 +79,7 @@ class TestDiffusionLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): result = DiffusionLoss(**input_param).forward(**input_data) diff --git a/tests/test_dints_cell.py b/tests/test_dints_cell.py index 21cef39d68..13990da373 100644 --- a/tests/test_dints_cell.py +++ b/tests/test_dints_cell.py @@ -98,6 +98,7 @@ class TestCell(unittest.TestCase): + @parameterized.expand(TEST_CASES_2D + TEST_CASES_3D) def test_cell_3d(self, input_param, ops, weight, input_shape, expected_shape): net = Cell(**input_param) diff --git a/tests/test_dints_mixop.py b/tests/test_dints_mixop.py index 09d2e7a423..683a8d1005 100644 --- a/tests/test_dints_mixop.py +++ b/tests/test_dints_mixop.py @@ -61,6 +61,7 @@ class TestMixOP(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) def test_mixop_3d(self, input_param, ops, weight, input_shape, expected_shape): net = MixedOp(ops=Cell.OPS3D, **input_param) diff --git a/tests/test_dints_network.py b/tests/test_dints_network.py index 532c31886b..5ee4db7a4e 100644 --- a/tests/test_dints_network.py +++ b/tests/test_dints_network.py @@ -115,6 +115,7 @@ @skip_if_quick class TestDints(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D + TEST_CASES_2D) def test_dints_inference(self, dints_grid_params, dints_params, input_shape, expected_shape): grid = TopologySearch(**dints_grid_params) @@ -155,6 +156,7 @@ def test_dints_search(self, dints_grid_params, dints_params, input_shape, expect @SkipIfBeforePyTorchVersion((1, 9)) class TestDintsTS(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D + TEST_CASES_2D) def test_script(self, dints_grid_params, dints_params, input_shape, _): grid = TopologyInstance(**dints_grid_params) diff --git a/tests/test_discriminator.py b/tests/test_discriminator.py index 62635e286e..f615605e56 100644 --- a/tests/test_discriminator.py +++ b/tests/test_discriminator.py @@ -42,6 +42,7 @@ class TestDiscriminator(unittest.TestCase): + @parameterized.expand(CASES) def test_shape(self, input_param, input_data, expected_shape): net = Discriminator(**input_param) diff --git a/tests/test_distance_transform_edt.py b/tests/test_distance_transform_edt.py index 83b9348348..cf5c253c0c 100644 --- a/tests/test_distance_transform_edt.py +++ b/tests/test_distance_transform_edt.py @@ -146,6 +146,7 @@ class TestDistanceTransformEDT(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_scipy_transform(self, input, expected_output): transform = DistanceTransformEDT() diff --git a/tests/test_download_and_extract.py b/tests/test_download_and_extract.py index 696bcfc78f..555f7dc250 100644 --- a/tests/test_download_and_extract.py +++ b/tests/test_download_and_extract.py @@ -24,6 +24,7 @@ class TestDownloadAndExtract(unittest.TestCase): + @skip_if_quick def test_actions(self): testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") diff --git a/tests/test_download_url_yandex.py b/tests/test_download_url_yandex.py index a08105a93f..54d39b06ff 100644 --- a/tests/test_download_url_yandex.py +++ b/tests/test_download_url_yandex.py @@ -29,6 +29,7 @@ class TestDownloadUrlYandex(unittest.TestCase): + @unittest.skip("data source unstable") def test_verify(self): with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_downsample_block.py b/tests/test_downsample_block.py index cd40be4306..34afa248ad 100644 --- a/tests/test_downsample_block.py +++ b/tests/test_downsample_block.py @@ -37,6 +37,7 @@ class TestMaxAvgPool(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_shape, expected_shape): net = MaxAvgPool(**input_param) diff --git a/tests/test_drop_path.py b/tests/test_drop_path.py index ab2150e548..1b9974791a 100644 --- a/tests/test_drop_path.py +++ b/tests/test_drop_path.py @@ -28,6 +28,7 @@ class TestDropPath(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_shape): im = torch.rand(input_shape) diff --git a/tests/test_ds_loss.py b/tests/test_ds_loss.py index de7aec1ced..daa4ed1e5e 100644 --- a/tests/test_ds_loss.py +++ b/tests/test_ds_loss.py @@ -135,6 +135,7 @@ class TestDSLossDiceCE(unittest.TestCase): + @parameterized.expand(TEST_CASES_DICECE) def test_result(self, input_param, input_param2, input_data, expected_val): diceceloss = DeepSupervisionLoss(DiceCELoss(**input_param), **input_param2) @@ -160,6 +161,7 @@ def test_script(self): @SkipIfBeforePyTorchVersion((1, 11)) class TestDSLossDiceCE2(unittest.TestCase): + @parameterized.expand(TEST_CASES_DICECE2) def test_result(self, input_param, input_param2, input_data, expected_val): diceceloss = DeepSupervisionLoss(DiceCELoss(**input_param), **input_param2) @@ -169,6 +171,7 @@ def test_result(self, input_param, input_param2, input_data, expected_val): @SkipIfBeforePyTorchVersion((1, 11)) class TestDSLossDice(unittest.TestCase): + @parameterized.expand(TEST_CASES_DICE) def test_result(self, input_param, input_data, expected_val): loss = DeepSupervisionLoss(DiceLoss(**input_param)) @@ -178,6 +181,7 @@ def test_result(self, input_param, input_data, expected_val): @SkipIfBeforePyTorchVersion((1, 11)) class TestDSLossDiceFocal(unittest.TestCase): + @parameterized.expand(TEST_CASES_DICEFOCAL) def test_result(self, input_param, input_data, expected_val): loss = DeepSupervisionLoss(DiceFocalLoss(**input_param)) diff --git a/tests/test_dvf2ddf.py b/tests/test_dvf2ddf.py index f18b5b7297..b385b897e5 100644 --- a/tests/test_dvf2ddf.py +++ b/tests/test_dvf2ddf.py @@ -42,6 +42,7 @@ class TestDVF2DDF(unittest.TestCase): + def setUp(self): set_determinism(0) diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 247da14b7d..f3c982056c 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -11,6 +11,7 @@ from __future__ import annotations +import platform import unittest from typing import Any, Sequence @@ -24,6 +25,12 @@ InstanceNorm3dNVFuser, _ = optional_import("apex.normalization", name="InstanceNorm3dNVFuser") +ON_AARCH64 = platform.machine() == "aarch64" +if ON_AARCH64: + rtol, atol = 1e-2, 1e-2 +else: + rtol, atol = 1e-4, 1e-4 + device = "cuda" if torch.cuda.is_available() else "cpu" strides: Sequence[Sequence[int] | int] @@ -109,6 +116,7 @@ class TestDynUNet(unittest.TestCase): + @parameterized.expand(TEST_CASE_DYNUNET_3D) def test_shape(self, input_param, input_shape, expected_shape): net = DynUNet(**input_param).to(device) @@ -128,6 +136,7 @@ def test_script(self): @skip_if_no_cuda @skip_if_windows class TestDynUNetWithInstanceNorm3dNVFuser(unittest.TestCase): + def setUp(self): try: layer = InstanceNorm3dNVFuser(num_features=1, affine=False).to("cuda:0") @@ -157,10 +166,11 @@ def test_consistency(self, input_param, input_shape, _): with eval_mode(net_fuser): result_fuser = net_fuser(input_tensor) - assert_allclose(result, result_fuser, rtol=1e-4, atol=1e-4) + assert_allclose(result, result_fuser, rtol=rtol, atol=atol) class TestDynUNetDeepSupervision(unittest.TestCase): + @parameterized.expand(TEST_CASE_DEEP_SUPERVISION) def test_shape(self, input_param, input_shape, expected_shape): net = DynUNet(**input_param).to(device) diff --git a/tests/test_dynunet_block.py b/tests/test_dynunet_block.py index b34ccb31ba..4d9e06670b 100644 --- a/tests/test_dynunet_block.py +++ b/tests/test_dynunet_block.py @@ -73,6 +73,7 @@ class TestResBasicBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_RES_BASIC_BLOCK) def test_shape(self, input_param, input_shape, expected_shape): for net in [UnetResBlock(**input_param), UnetBasicBlock(**input_param)]: @@ -96,6 +97,7 @@ def test_script(self): class TestUpBlock(unittest.TestCase): + @parameterized.expand(TEST_UP_BLOCK) def test_shape(self, input_param, input_shape, expected_shape, skip_shape): net = UnetUpBlock(**input_param) diff --git a/tests/test_efficientnet.py b/tests/test_efficientnet.py index 5bdad5a568..c16526eaa3 100644 --- a/tests/test_efficientnet.py +++ b/tests/test_efficientnet.py @@ -248,6 +248,7 @@ def make_shape_cases( class TestEFFICIENTNET(unittest.TestCase): + @parameterized.expand(CASES_1D + CASES_2D + CASES_3D + CASES_VARIATIONS) def test_shape(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" @@ -376,6 +377,7 @@ def test_script(self): class TestExtractFeatures(unittest.TestCase): + @parameterized.expand(CASE_EXTRACT_FEATURES) def test_shape(self, input_param, input_shape, expected_shapes): device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/tests/test_ensemble_evaluator.py b/tests/test_ensemble_evaluator.py index 40a4a72dd5..ad81d35d52 100644 --- a/tests/test_ensemble_evaluator.py +++ b/tests/test_ensemble_evaluator.py @@ -26,11 +26,13 @@ class TestEnsembleEvaluator(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_content(self, pred_keys): device = torch.device("cpu:0") class TestDataset(torch.utils.data.Dataset): + def __len__(self): return 8 @@ -40,6 +42,7 @@ def __getitem__(self, index): val_loader = torch.utils.data.DataLoader(TestDataset()) class TestNet(torch.nn.Module): + def __init__(self, func): super().__init__() self.func = func diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py index 027b18b7dd..0c9ad5869e 100644 --- a/tests/test_ensure_channel_first.py +++ b/tests/test_ensure_channel_first.py @@ -46,6 +46,7 @@ class TestEnsureChannelFirst(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) @unittest.skipUnless(has_itk, "itk not installed") def test_load_nifti(self, input_param, filenames, original_channel_dim): diff --git a/tests/test_ensure_channel_firstd.py b/tests/test_ensure_channel_firstd.py index 08e2709641..63a437894b 100644 --- a/tests/test_ensure_channel_firstd.py +++ b/tests/test_ensure_channel_firstd.py @@ -32,6 +32,7 @@ class TestEnsureChannelFirstd(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_load_nifti(self, input_param, filenames, original_channel_dim): if original_channel_dim is None: diff --git a/tests/test_ensure_tuple.py b/tests/test_ensure_tuple.py index dc6649ec4c..ec8c92785a 100644 --- a/tests/test_ensure_tuple.py +++ b/tests/test_ensure_tuple.py @@ -37,6 +37,7 @@ class TestEnsureTuple(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, input, expected_value, wrap_array=False): result = ensure_tuple(input, wrap_array) diff --git a/tests/test_ensure_type.py b/tests/test_ensure_type.py index 7d6b7ca586..00b01898b3 100644 --- a/tests/test_ensure_type.py +++ b/tests/test_ensure_type.py @@ -22,6 +22,7 @@ class TestEnsureType(unittest.TestCase): + def test_array_input(self): test_datas = [np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])] if torch.cuda.is_available(): diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py index 4fa942e742..09aa1f04b5 100644 --- a/tests/test_ensure_typed.py +++ b/tests/test_ensure_typed.py @@ -22,6 +22,7 @@ class TestEnsureTyped(unittest.TestCase): + def test_array_input(self): test_datas = [np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])] if torch.cuda.is_available(): diff --git a/tests/test_enum_bound_interp.py b/tests/test_enum_bound_interp.py index 5a63fc05af..cd3119f91c 100644 --- a/tests/test_enum_bound_interp.py +++ b/tests/test_enum_bound_interp.py @@ -22,6 +22,7 @@ @skip_if_no_cpp_extension class TestEnumBoundInterp(unittest.TestCase): + def test_bound(self): self.assertEqual(str(b.replicate), "BoundType.replicate") self.assertEqual(str(b.nearest), "BoundType.replicate") diff --git a/tests/test_eval_mode.py b/tests/test_eval_mode.py index 8458753e1f..b40bb78327 100644 --- a/tests/test_eval_mode.py +++ b/tests/test_eval_mode.py @@ -19,6 +19,7 @@ class TestEvalMode(unittest.TestCase): + def test_eval_mode(self): t = torch.rand(1, 1, 4, 4) p = torch.nn.Conv2d(1, 1, 3) diff --git a/tests/test_evenly_divisible_all_gather_dist.py b/tests/test_evenly_divisible_all_gather_dist.py index f338944daa..d6d26c7e23 100644 --- a/tests/test_evenly_divisible_all_gather_dist.py +++ b/tests/test_evenly_divisible_all_gather_dist.py @@ -21,6 +21,7 @@ class DistributedEvenlyDivisibleAllGather(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) def test_data(self): self._run() diff --git a/tests/test_factorized_increase.py b/tests/test_factorized_increase.py index f7642ff357..b082c70090 100644 --- a/tests/test_factorized_increase.py +++ b/tests/test_factorized_increase.py @@ -25,6 +25,7 @@ class TestFactInc(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) def test_factorized_increase_3d(self, input_param, input_shape, expected_shape): net = FactorizedIncreaseBlock(**input_param) diff --git a/tests/test_factorized_reduce.py b/tests/test_factorized_reduce.py index 224a0cb351..5e879c3cb5 100644 --- a/tests/test_factorized_reduce.py +++ b/tests/test_factorized_reduce.py @@ -25,6 +25,7 @@ class TestFactRed(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) def test_factorized_reduce_3d(self, input_param, input_shape, expected_shape): net = FactorizedReduceBlock(**input_param) diff --git a/tests/test_fastmri_reader.py b/tests/test_fastmri_reader.py index b15bd4b6a2..af2eed7db5 100644 --- a/tests/test_fastmri_reader.py +++ b/tests/test_fastmri_reader.py @@ -65,6 +65,7 @@ class TestMRIUtils(unittest.TestCase): + @parameterized.expand([TEST_CASE1, TEST_CASE2]) def test_get_data(self, test_data, test_res, test_meta): reader = FastMRIReader() diff --git a/tests/test_fft_utils.py b/tests/test_fft_utils.py index 971df2b411..7c7035770a 100644 --- a/tests/test_fft_utils.py +++ b/tests/test_fft_utils.py @@ -42,6 +42,7 @@ class TestFFT(unittest.TestCase): + @parameterized.expand(TESTS) def test(self, test_data, res_data): result = fftn_centered(test_data, spatial_dims=2, is_complex=False) diff --git a/tests/test_fg_bg_to_indices.py b/tests/test_fg_bg_to_indices.py index 7d88bb7ee9..a28c491333 100644 --- a/tests/test_fg_bg_to_indices.py +++ b/tests/test_fg_bg_to_indices.py @@ -72,6 +72,7 @@ class TestFgBgToIndices(unittest.TestCase): + @parameterized.expand(TESTS_CASES) def test_type_shape(self, input_data, label, image, expected_fg, expected_bg): fg_indices, bg_indices = FgBgToIndices(**input_data)(label, image) diff --git a/tests/test_fg_bg_to_indicesd.py b/tests/test_fg_bg_to_indicesd.py index d0d1ae5fb6..c6dd2059f4 100644 --- a/tests/test_fg_bg_to_indicesd.py +++ b/tests/test_fg_bg_to_indicesd.py @@ -67,6 +67,7 @@ class TestFgBgToIndicesd(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_type_shape(self, input_data, data, expected_fg, expected_bg): result = FgBgToIndicesd(**input_data)(data) diff --git a/tests/test_file_basename.py b/tests/test_file_basename.py index 93e2027575..27e2d98c7d 100644 --- a/tests/test_file_basename.py +++ b/tests/test_file_basename.py @@ -20,6 +20,7 @@ class TestFilename(unittest.TestCase): + def test_value(self): with tempfile.TemporaryDirectory() as tempdir: output_tmp = os.path.join(tempdir, "output") diff --git a/tests/test_fill_holes.py b/tests/test_fill_holes.py index 65c59d49eb..241f7f8254 100644 --- a/tests/test_fill_holes.py +++ b/tests/test_fill_holes.py @@ -195,6 +195,7 @@ class TestFillHoles(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_results(self, _, args, input_image, expected): converter = FillHoles(**args) diff --git a/tests/test_fill_holesd.py b/tests/test_fill_holesd.py index 3f98dab1bf..28c17b00ac 100644 --- a/tests/test_fill_holesd.py +++ b/tests/test_fill_holesd.py @@ -196,6 +196,7 @@ class TestFillHoles(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_results(self, _, args, input_image, expected): key = CommonKeys.IMAGE diff --git a/tests/test_fl_exchange_object.py b/tests/test_fl_exchange_object.py index 293f9d518b..dab4eae037 100644 --- a/tests/test_fl_exchange_object.py +++ b/tests/test_fl_exchange_object.py @@ -46,6 +46,7 @@ @SkipIfNoModule("torchvision") @SkipIfNoModule("ignite") class TestFLExchangeObject(unittest.TestCase): + @parameterized.expand([TEST_INIT_1, TEST_INIT_2]) def test_init(self, input_params, expected_str): eo = ExchangeObject(**input_params) diff --git a/tests/test_fl_monai_algo.py b/tests/test_fl_monai_algo.py index ca781ff166..c8cb3451fc 100644 --- a/tests/test_fl_monai_algo.py +++ b/tests/test_fl_monai_algo.py @@ -75,7 +75,7 @@ tracking={ "handlers_id": DEFAULT_HANDLERS_ID, "configs": { - "execute_config": f"{_data_dir}/config_executed.json", + "save_execute_config": f"{_data_dir}/config_executed.json", "trainer": { "_target_": "MLFlowHandler", "tracking_uri": path_to_uri(_data_dir) + "/mlflow_override", @@ -181,6 +181,7 @@ @SkipIfNoModule("ignite") @SkipIfNoModule("mlflow") class TestFLMonaiAlgo(unittest.TestCase): + @parameterized.expand([TEST_TRAIN_1, TEST_TRAIN_2, TEST_TRAIN_3, TEST_TRAIN_4]) def test_train(self, input_params): # initialize algo @@ -200,7 +201,7 @@ def test_train(self, input_params): algo.finalize() # test experiment management - if "execute_config" in algo.train_workflow.parser: + if "save_execute_config" in algo.train_workflow.parser: self.assertTrue(os.path.exists(f"{_data_dir}/mlflow_override")) shutil.rmtree(f"{_data_dir}/mlflow_override") self.assertTrue(os.path.exists(f"{_data_dir}/config_executed.json")) @@ -223,7 +224,7 @@ def test_evaluate(self, input_params): algo.evaluate(data=data, extra={}) # test experiment management - if "execute_config" in algo.eval_workflow.parser: + if "save_execute_config" in algo.eval_workflow.parser: self.assertGreater(len(list(glob.glob(f"{_data_dir}/mlflow_*"))), 0) for f in list(glob.glob(f"{_data_dir}/mlflow_*")): shutil.rmtree(f) diff --git a/tests/test_fl_monai_algo_dist.py b/tests/test_fl_monai_algo_dist.py index 1302ab6618..d8dbfa5339 100644 --- a/tests/test_fl_monai_algo_dist.py +++ b/tests/test_fl_monai_algo_dist.py @@ -32,6 +32,7 @@ @SkipIfNoModule("ignite") @SkipIfBeforePyTorchVersion((1, 11, 1)) class TestFLMonaiAlgo(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2, init_method="no_init") @skip_if_no_cuda def test_train(self): diff --git a/tests/test_fl_monai_algo_stats.py b/tests/test_fl_monai_algo_stats.py index 307b3f539c..6e58f8af88 100644 --- a/tests/test_fl_monai_algo_stats.py +++ b/tests/test_fl_monai_algo_stats.py @@ -64,6 +64,7 @@ @SkipIfNoModule("ignite") class TestFLMonaiAlgo(unittest.TestCase): + @parameterized.expand([TEST_GET_DATA_STATS_1, TEST_GET_DATA_STATS_2, TEST_GET_DATA_STATS_3]) def test_get_data_stats(self, input_params): # initialize algo diff --git a/tests/test_flatten_sub_keysd.py b/tests/test_flatten_sub_keysd.py index 997f203870..1a642e5fc4 100644 --- a/tests/test_flatten_sub_keysd.py +++ b/tests/test_flatten_sub_keysd.py @@ -46,6 +46,7 @@ class TestFlattenSubKeysd(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_dict(self, params, input_data, expected): result = FlattenSubKeysd(**params)(input_data) diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py index 1d831f0976..42baa28b71 100644 --- a/tests/test_flexible_unet.py +++ b/tests/test_flexible_unet.py @@ -23,18 +23,18 @@ EfficientNetBNFeatures, FlexibleUNet, FlexUNetEncoderRegister, - ResNet, - ResNetBlock, - ResNetBottleneck, + ResNetEncoder, + ResNetFeatures, ) from monai.utils import optional_import -from tests.utils import skip_if_downloading_fails, skip_if_quick +from tests.utils import SkipIfNoModule, skip_if_downloading_fails, skip_if_quick torchvision, has_torchvision = optional_import("torchvision") PIL, has_pil = optional_import("PIL") class DummyEncoder(BaseEncoder): + @classmethod def get_encoder_parameters(cls): basic_dict = {"spatial_dims": 2, "in_channels": 3, "pretrained": False} @@ -58,101 +58,6 @@ def get_encoder_names(cls): return ["encoder_wrong_channels", "encoder_no_param1", "encoder_no_param2", "encoder_no_param3"] -class ResNetEncoder(ResNet, BaseEncoder): - backbone_names = ["resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"] - output_feature_channels = [(64, 128, 256, 512)] * 3 + [(256, 512, 1024, 2048)] * 4 - parameter_layers = [ - [1, 1, 1, 1], - [2, 2, 2, 2], - [3, 4, 6, 3], - [3, 4, 6, 3], - [3, 4, 23, 3], - [3, 8, 36, 3], - [3, 24, 36, 3], - ] - - def __init__(self, in_channels, pretrained, **kargs): - super().__init__(**kargs, n_input_channels=in_channels) - if pretrained: - # Author of paper zipped the state_dict on googledrive, - # so would need to download, unzip and read (2.8gb file for a ~150mb state dict). - # Would like to load dict from url but need somewhere to save the state dicts. - raise NotImplementedError( - "Currently not implemented. You need to manually download weights provided by the paper's author" - " and load then to the model with `state_dict`. See https://github.com/Tencent/MedicalNet" - ) - - @staticmethod - def get_inplanes(): - return [64, 128, 256, 512] - - @classmethod - def get_encoder_parameters(cls) -> list[dict]: - """ - Get parameter list to initialize encoder networks. - Each parameter dict must have `spatial_dims`, `in_channels` - and `pretrained` parameters. - """ - parameter_list = [] - res_type: type[ResNetBlock] | type[ResNetBottleneck] - for backbone in range(len(cls.backbone_names)): - if backbone < 3: - res_type = ResNetBlock - else: - res_type = ResNetBottleneck - parameter_list.append( - { - "block": res_type, - "layers": cls.parameter_layers[backbone], - "block_inplanes": ResNetEncoder.get_inplanes(), - "spatial_dims": 2, - "in_channels": 3, - "pretrained": False, - } - ) - return parameter_list - - @classmethod - def num_channels_per_output(cls): - """ - Get number of output features' channel. - """ - return cls.output_feature_channels - - @classmethod - def num_outputs(cls): - """ - Get number of output feature. - """ - return [4] * 7 - - @classmethod - def get_encoder_names(cls): - """ - Get the name string of backbones which will be used to initialize flexible unet. - """ - return cls.backbone_names - - def forward(self, x: torch.Tensor): - feature_list = [] - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - if not self.no_max_pool: - x = self.maxpool(x) - x = self.layer1(x) - feature_list.append(x) - x = self.layer2(x) - feature_list.append(x) - x = self.layer3(x) - feature_list.append(x) - x = self.layer4(x) - feature_list.append(x) - - return feature_list - - -FLEXUNET_BACKBONE.register_class(ResNetEncoder) FLEXUNET_BACKBONE.register_class(DummyEncoder) @@ -203,9 +108,7 @@ def make_shape_cases( def make_error_case(): - error_dummy_backbones = DummyEncoder.get_encoder_names() - error_resnet_backbones = ResNetEncoder.get_encoder_names() - error_backbones = error_dummy_backbones + error_resnet_backbones + error_backbones = DummyEncoder.get_encoder_names() error_param_list = [] for backbone in error_backbones: error_param_list.append( @@ -231,7 +134,7 @@ def make_error_case(): norm="instance", ) CASES_3D = make_shape_cases( - models=[SEL_MODELS[0]], + models=[SEL_MODELS[0], SEL_MODELS[2]], spatial_dims=[3], batches=[1], pretrained=[False], @@ -344,6 +247,7 @@ def make_error_case(): "spatial_dims": 2, "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}), }, + EfficientNetBNFeatures, { "in_channels": 3, "num_classes": 10, @@ -353,7 +257,20 @@ def make_error_case(): "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}), }, ["_conv_stem.weight"], - ) + ), + ( + { + "in_channels": 1, + "out_channels": 10, + "backbone": SEL_MODELS[2], + "pretrained": True, + "spatial_dims": 3, + "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}), + }, + ResNetFeatures, + {"model_name": SEL_MODELS[2], "pretrained": True, "spatial_dims": 3, "in_channels": 1}, + ["conv1.weight"], + ), ] CASE_ERRORS = make_error_case() @@ -362,8 +279,10 @@ def make_error_case(): CASE_REGISTER_ENCODER = ["EfficientNetEncoder", "monai.networks.nets.EfficientNetEncoder"] +@SkipIfNoModule("hf_hub_download") @skip_if_quick class TestFLEXIBLEUNET(unittest.TestCase): + @parameterized.expand(CASES_2D + CASES_3D + CASES_VARIATIONS) def test_shape(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" @@ -379,19 +298,19 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) @parameterized.expand(CASES_PRETRAIN) - def test_pretrain(self, input_param, efficient_input_param, weight_list): + def test_pretrain(self, flexunet_input_param, feature_extractor_class, feature_extractor_input_param, weight_list): device = "cuda" if torch.cuda.is_available() else "cpu" with skip_if_downloading_fails(): - net = FlexibleUNet(**input_param).to(device) + net = FlexibleUNet(**flexunet_input_param).to(device) with skip_if_downloading_fails(): - eff_net = EfficientNetBNFeatures(**efficient_input_param).to(device) + feature_extractor_net = feature_extractor_class(**feature_extractor_input_param).to(device) for weight_name in weight_list: - if weight_name in net.encoder.state_dict() and weight_name in eff_net.state_dict(): + if weight_name in net.encoder.state_dict() and weight_name in feature_extractor_net.state_dict(): net_weight = net.encoder.state_dict()[weight_name] - download_weight = eff_net.state_dict()[weight_name] + download_weight = feature_extractor_net.state_dict()[weight_name] weight_diff = torch.abs(net_weight - download_weight) diff_sum = torch.sum(weight_diff) # check if a weight in weight_list equals to the downloaded weight. @@ -404,6 +323,7 @@ def test_error_raise(self, input_param): class TestFlexUNetEncoderRegister(unittest.TestCase): + @parameterized.expand(CASE_REGISTER_ENCODER) def test_regist(self, encoder): tmp_backbone = FlexUNetEncoderRegister() diff --git a/tests/test_flip.py b/tests/test_flip.py index d7df55fde0..789ec86920 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -34,6 +34,7 @@ class TestFlip(NumpyImageTestCase2D): + @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, _, spatial_axis, raises): with self.assertRaises(raises): diff --git a/tests/test_flipd.py b/tests/test_flipd.py index 19f9ed0882..277f387051 100644 --- a/tests/test_flipd.py +++ b/tests/test_flipd.py @@ -35,6 +35,7 @@ class TestFlipd(NumpyImageTestCase2D): + @parameterized.expand(INVALID_CASES) def test_invalid_cases(self, _, spatial_axis, raises): with self.assertRaises(raises): diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index 57df6a3460..0bb8a078ae 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -79,6 +79,7 @@ class TestFocalLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_result(self, input_param, input_data, expected_val): focal_loss = FocalLoss(**input_param) @@ -131,7 +132,7 @@ def test_consistency_with_cross_entropy_2d_no_reduction(self): error = np.abs(a - b) max_error = np.maximum(error, max_error) - assert np.allclose(max_error, 0) + assert np.allclose(max_error, 0, atol=1e-6) def test_consistency_with_cross_entropy_2d_onehot_label(self): """For gamma=0 the focal loss reduces to the cross entropy loss""" diff --git a/tests/test_folder_layout.py b/tests/test_folder_layout.py index d6d4bdf679..6f72eee51f 100644 --- a/tests/test_folder_layout.py +++ b/tests/test_folder_layout.py @@ -60,6 +60,7 @@ class TestFolderLayout(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_value(self, con_params, f_params, expected): fname = FolderLayout(**con_params).filename(**f_params) diff --git a/tests/test_foreground_mask.py b/tests/test_foreground_mask.py index eb59ae2db6..1aa54f4d3a 100644 --- a/tests/test_foreground_mask.py +++ b/tests/test_foreground_mask.py @@ -81,6 +81,7 @@ @unittest.skipUnless(has_skimage, "Requires sci-kit image") class TestForegroundMask(unittest.TestCase): + @parameterized.expand(TESTS) def test_foreground_mask(self, in_type, arguments, image, mask): input_image = in_type(image) diff --git a/tests/test_foreground_maskd.py b/tests/test_foreground_maskd.py index 24cb233c30..dc7b6cfb24 100644 --- a/tests/test_foreground_maskd.py +++ b/tests/test_foreground_maskd.py @@ -89,6 +89,7 @@ @unittest.skipUnless(has_skimage, "Requires sci-kit image") class TestForegroundMaskd(unittest.TestCase): + @parameterized.expand(TESTS) def test_foreground_mask(self, in_type, arguments, data_dict, mask): data_dict[arguments["keys"]] = in_type(data_dict[arguments["keys"]]) diff --git a/tests/test_fourier.py b/tests/test_fourier.py index 3613db989f..177fc280f7 100644 --- a/tests/test_fourier.py +++ b/tests/test_fourier.py @@ -28,6 +28,7 @@ @SkipIfBeforePyTorchVersion((1, 8)) @SkipIfNoModule("torch.fft") class TestFourier(unittest.TestCase): + def setUp(self): set_determinism(0) super().setUp() diff --git a/tests/test_fpn_block.py b/tests/test_fpn_block.py index c6121c5b98..969800e80a 100644 --- a/tests/test_fpn_block.py +++ b/tests/test_fpn_block.py @@ -44,6 +44,7 @@ class TestFPNBlock(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_fpn_block(self, input_param, input_shape, expected_shape): net = FeaturePyramidNetwork(**input_param) @@ -67,6 +68,7 @@ def test_script(self, input_param, input_shape, expected_shape): @unittest.skipUnless(has_torchvision, "Requires torchvision") class TestFPN(unittest.TestCase): + @parameterized.expand(TEST_CASES2) def test_fpn(self, input_param, input_shape, expected_shape): net = _resnet_fpn_extractor(backbone=resnet50(), spatial_dims=input_param["spatial_dims"], returned_layers=[1]) diff --git a/tests/test_freeze_layers.py b/tests/test_freeze_layers.py index 29594ed98a..1bea4ed1b5 100644 --- a/tests/test_freeze_layers.py +++ b/tests/test_freeze_layers.py @@ -27,6 +27,7 @@ class TestModuleState(unittest.TestCase): + def tearDown(self): set_determinism(None) diff --git a/tests/test_from_engine_hovernet.py b/tests/test_from_engine_hovernet.py index 227fa66baa..7d1a784466 100644 --- a/tests/test_from_engine_hovernet.py +++ b/tests/test_from_engine_hovernet.py @@ -28,6 +28,7 @@ class TestFromEngineHovernet(unittest.TestCase): + @parameterized.expand(CASES) def test_results(self, input, expected): output = from_engine_hovernet(keys=["A", "B"], nested_key="C")(input) diff --git a/tests/test_fullyconnectednet.py b/tests/test_fullyconnectednet.py index 94fc4caa6e..863d1399a9 100644 --- a/tests/test_fullyconnectednet.py +++ b/tests/test_fullyconnectednet.py @@ -42,6 +42,7 @@ class TestFullyConnectedNet(unittest.TestCase): + def setUp(self): self.batch_size = 10 self.inSize = 10 diff --git a/tests/test_gaussian.py b/tests/test_gaussian.py index b98507b793..689d8088f9 100644 --- a/tests/test_gaussian.py +++ b/tests/test_gaussian.py @@ -224,6 +224,7 @@ class TestGaussian1d(unittest.TestCase): + def test_gaussian(self): np.testing.assert_allclose( gaussian_1d(0.5, 8), diff --git a/tests/test_gaussian_filter.py b/tests/test_gaussian_filter.py index 1beee579e8..2167591c66 100644 --- a/tests/test_gaussian_filter.py +++ b/tests/test_gaussian_filter.py @@ -18,7 +18,7 @@ from parameterized import parameterized from monai.networks.layers import GaussianFilter -from tests.utils import skip_if_quick +from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_quick TEST_CASES = [[{"type": "erf", "gt": 2.0}], [{"type": "scalespace", "gt": 3.0}], [{"type": "sampled", "gt": 5.0}]] TEST_CASES_GPU = [[{"type": "erf", "gt": 0.8, "device": "cuda"}], [{"type": "sampled", "gt": 5.0, "device": "cuda"}]] @@ -34,7 +34,9 @@ ] +@SkipIfAtLeastPyTorchVersion((2, 2, 0)) # https://github.com/Project-MONAI/MONAI/issues/7445 class TestGaussianFilterBackprop(unittest.TestCase): + def code_to_run(self, input_args): input_dims = input_args.get("dims", (2, 3, 8)) device = ( @@ -93,7 +95,9 @@ def test_train_slow(self, input_args): self.code_to_run(input_args) +@SkipIfAtLeastPyTorchVersion((2, 2, 0)) # https://github.com/Project-MONAI/MONAI/issues/7445 class GaussianFilterTestCase(unittest.TestCase): + def test_1d(self): a = torch.ones(1, 8, 10) g = GaussianFilter(1, 3, 3).to(torch.device("cpu:0")) diff --git a/tests/test_gaussian_sharpen.py b/tests/test_gaussian_sharpen.py index 2509a4fc26..392a7b376b 100644 --- a/tests/test_gaussian_sharpen.py +++ b/tests/test_gaussian_sharpen.py @@ -82,6 +82,7 @@ class TestGaussianSharpen(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, arguments, image, expected_data): result = GaussianSharpen(**arguments)(image) diff --git a/tests/test_gaussian_sharpend.py b/tests/test_gaussian_sharpend.py index 75ea915d96..15b219fd2c 100644 --- a/tests/test_gaussian_sharpend.py +++ b/tests/test_gaussian_sharpend.py @@ -82,6 +82,7 @@ class TestGaussianSharpend(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, arguments, image, expected_data): result = GaussianSharpend(**arguments)(image) diff --git a/tests/test_gaussian_smooth.py b/tests/test_gaussian_smooth.py index 38b29bbd17..9f99ebe0f8 100644 --- a/tests/test_gaussian_smooth.py +++ b/tests/test_gaussian_smooth.py @@ -86,6 +86,7 @@ class TestGaussianSmooth(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, arguments, image, expected_data): result = GaussianSmooth(**arguments)(image) diff --git a/tests/test_gaussian_smoothd.py b/tests/test_gaussian_smoothd.py index 8702c073c8..a6de4a159b 100644 --- a/tests/test_gaussian_smoothd.py +++ b/tests/test_gaussian_smoothd.py @@ -86,6 +86,7 @@ class TestGaussianSmoothd(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, arguments, image, expected_data): result = GaussianSmoothd(**arguments)(image) diff --git a/tests/test_gdsdataset.py b/tests/test_gdsdataset.py index 29f2d0096b..f0a419dcf5 100644 --- a/tests/test_gdsdataset.py +++ b/tests/test_gdsdataset.py @@ -64,6 +64,7 @@ class _InplaceXform(Transform): + def __call__(self, data): data[0] = data[0] + 1 return data @@ -73,6 +74,7 @@ def __call__(self, data): @unittest.skipUnless(has_nib, "Requires nibabel package.") @unittest.skipUnless(has_kvikio_numpy, "Requires scikit-image library.") class TestDataset(unittest.TestCase): + def test_cache(self): """testing no inplace change to the hashed item""" for p in TEST_NDARRAYS[:2]: diff --git a/tests/test_generalized_dice_focal_loss.py b/tests/test_generalized_dice_focal_loss.py index 33f6653212..65252611ca 100644 --- a/tests/test_generalized_dice_focal_loss.py +++ b/tests/test_generalized_dice_focal_loss.py @@ -21,6 +21,7 @@ class TestGeneralizedDiceFocalLoss(unittest.TestCase): + def test_result_onehot_target_include_bg(self): size = [3, 3, 5, 5] label = torch.randint(low=0, high=2, size=size) @@ -58,8 +59,18 @@ def test_result_no_onehot_no_bg(self): def test_ill_shape(self): loss = GeneralizedDiceFocalLoss() - with self.assertRaisesRegex(ValueError, ""): - loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + with self.assertRaises(AssertionError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5))) + + def test_ill_shape2(self): + loss = GeneralizedDiceFocalLoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + + def test_ill_shape3(self): + loss = GeneralizedDiceFocalLoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4))) def test_ill_lambda(self): with self.assertRaisesRegex(ValueError, ""): diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index d8ba496d03..7499507129 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -142,6 +142,7 @@ class TestGeneralizedDiceLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): result = GeneralizedDiceLoss(**input_param).forward(**input_data) diff --git a/tests/test_generalized_wasserstein_dice_loss.py b/tests/test_generalized_wasserstein_dice_loss.py index 7b85fdc5b6..6b9d57e831 100644 --- a/tests/test_generalized_wasserstein_dice_loss.py +++ b/tests/test_generalized_wasserstein_dice_loss.py @@ -24,6 +24,7 @@ class TestGeneralizedWassersteinDiceLoss(unittest.TestCase): + def test_bin_seg_2d(self): target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) @@ -160,6 +161,7 @@ def test_convergence(self): # define a model with one layer class OnelayerNet(nn.Module): + def __init__(self): super().__init__() self.layer = nn.Linear(num_voxels, num_voxels * num_classes) diff --git a/tests/test_generate_distance_map.py b/tests/test_generate_distance_map.py index 724a335e1a..42f5664647 100644 --- a/tests/test_generate_distance_map.py +++ b/tests/test_generate_distance_map.py @@ -36,6 +36,7 @@ class TestGenerateDistanceMap(unittest.TestCase): + @parameterized.expand(EXCEPTION_TESTS) def test_value(self, arguments, mask, probmap, exception_type): with self.assertRaises(exception_type): diff --git a/tests/test_generate_distance_mapd.py b/tests/test_generate_distance_mapd.py index 17c5aa782b..2bddadf5b8 100644 --- a/tests/test_generate_distance_mapd.py +++ b/tests/test_generate_distance_mapd.py @@ -55,6 +55,7 @@ class TestGenerateDistanceMapd(unittest.TestCase): + @parameterized.expand(EXCEPTION_TESTS) def test_value(self, arguments, mask, border_map, exception_type): with self.assertRaises(exception_type): diff --git a/tests/test_generate_instance_border.py b/tests/test_generate_instance_border.py index 8634bb7d77..fc1035dfe5 100644 --- a/tests/test_generate_instance_border.py +++ b/tests/test_generate_instance_border.py @@ -34,6 +34,7 @@ class TestGenerateInstanceBorder(unittest.TestCase): + @parameterized.expand(EXCEPTION_TESTS) def test_value(self, arguments, mask, hover_map, exception_type): with self.assertRaises(exception_type): diff --git a/tests/test_generate_instance_borderd.py b/tests/test_generate_instance_borderd.py index fc81e8f87c..cdfbee4193 100644 --- a/tests/test_generate_instance_borderd.py +++ b/tests/test_generate_instance_borderd.py @@ -44,6 +44,7 @@ class TestGenerateInstanceBorderd(unittest.TestCase): + @parameterized.expand(EXCEPTION_TESTS) def test_value(self, arguments, mask, hover_map, exception_type): with self.assertRaises(exception_type): diff --git a/tests/test_generate_instance_centroid.py b/tests/test_generate_instance_centroid.py index f9fdc602a9..6b4d533401 100644 --- a/tests/test_generate_instance_centroid.py +++ b/tests/test_generate_instance_centroid.py @@ -41,6 +41,7 @@ @unittest.skipUnless(has_skimage, "Requires scikit-image library.") class TestGenerateInstanceCentroid(unittest.TestCase): + @parameterized.expand(TEST_CASE) def test_shape(self, in_type, test_data, offset, expected): inst_bbox = get_bbox(test_data[None]) diff --git a/tests/test_generate_instance_centroidd.py b/tests/test_generate_instance_centroidd.py index 92e45cdf84..d381ad8c0e 100644 --- a/tests/test_generate_instance_centroidd.py +++ b/tests/test_generate_instance_centroidd.py @@ -41,6 +41,7 @@ @unittest.skipUnless(has_skimage, "Requires scikit-image library.") class TestGenerateInstanceCentroidd(unittest.TestCase): + @parameterized.expand(TEST_CASE) def test_shape(self, in_type, test_data, offset, expected): inst_bbox = get_bbox(test_data[None]) diff --git a/tests/test_generate_instance_contour.py b/tests/test_generate_instance_contour.py index 9058855e62..7f4290747d 100644 --- a/tests/test_generate_instance_contour.py +++ b/tests/test_generate_instance_contour.py @@ -46,6 +46,7 @@ @unittest.skipUnless(has_skimage, "Requires scikit-image library.") class TestGenerateInstanceContour(unittest.TestCase): + @parameterized.expand(TEST_CASE) def test_shape(self, in_type, test_data, min_num_points, offset, expected): inst_bbox = get_bbox(test_data[None]) diff --git a/tests/test_generate_instance_contourd.py b/tests/test_generate_instance_contourd.py index 22e3669850..5c831ee680 100644 --- a/tests/test_generate_instance_contourd.py +++ b/tests/test_generate_instance_contourd.py @@ -46,6 +46,7 @@ @unittest.skipUnless(has_skimage, "Requires scikit-image library.") class TestGenerateInstanceContourd(unittest.TestCase): + @parameterized.expand(TEST_CASE) def test_shape(self, in_type, test_data, min_num_points, offset, expected): inst_bbox = get_bbox(test_data[None]) diff --git a/tests/test_generate_instance_type.py b/tests/test_generate_instance_type.py index 354f8640ae..24e1d1b6d0 100644 --- a/tests/test_generate_instance_type.py +++ b/tests/test_generate_instance_type.py @@ -41,6 +41,7 @@ class TestGenerateInstanceType(unittest.TestCase): + @parameterized.expand(TEST_CASE) def test_shape(self, in_type, type_pred, seg_pred, bbox, expected): result = GenerateInstanceType()(in_type(type_pred[None]), in_type(seg_pred[None]), bbox, 1) diff --git a/tests/test_generate_instance_typed.py b/tests/test_generate_instance_typed.py index 84a5344503..958f68d6bb 100644 --- a/tests/test_generate_instance_typed.py +++ b/tests/test_generate_instance_typed.py @@ -41,6 +41,7 @@ class TestGenerateInstanceTyped(unittest.TestCase): + @parameterized.expand(TEST_CASE) def test_shape(self, in_type, type_pred, seg_pred, bbox, expected): test_data = {"type_pred": in_type(type_pred[None]), "seg": in_type(seg_pred[None]), "bbox": bbox, "id": 1} diff --git a/tests/test_generate_label_classes_crop_centers.py b/tests/test_generate_label_classes_crop_centers.py index c276171bd5..1cbb5f05c3 100644 --- a/tests/test_generate_label_classes_crop_centers.py +++ b/tests/test_generate_label_classes_crop_centers.py @@ -48,6 +48,7 @@ class TestGenerateLabelClassesCropCenters(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_type_shape(self, input_data, expected_type, expected_count, expected_shape): results = [] diff --git a/tests/test_generate_param_groups.py b/tests/test_generate_param_groups.py index 8301e40188..a78dba9f03 100644 --- a/tests/test_generate_param_groups.py +++ b/tests/test_generate_param_groups.py @@ -68,6 +68,7 @@ class TestGenerateParamGroups(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_lr_values(self, input_param, expected_values, expected_groups): device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/tests/test_generate_pos_neg_label_crop_centers.py b/tests/test_generate_pos_neg_label_crop_centers.py index 13b7b728b4..de127b33df 100644 --- a/tests/test_generate_pos_neg_label_crop_centers.py +++ b/tests/test_generate_pos_neg_label_crop_centers.py @@ -51,6 +51,7 @@ class TestGeneratePosNegLabelCropCenters(unittest.TestCase): + @parameterized.expand(TESTS) def test_type_shape(self, input_data, expected_type, expected_count, expected_shape): results = [] diff --git a/tests/test_generate_spatial_bounding_box.py b/tests/test_generate_spatial_bounding_box.py index a67e7d0175..6d5b415ec2 100644 --- a/tests/test_generate_spatial_bounding_box.py +++ b/tests/test_generate_spatial_bounding_box.py @@ -104,6 +104,7 @@ class TestGenerateSpatialBoundingBox(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, input_data, expected_box): result = generate_spatial_bounding_box(**input_data) diff --git a/tests/test_generate_succinct_contour.py b/tests/test_generate_succinct_contour.py index 1c60e99546..fc4f5660d9 100644 --- a/tests/test_generate_succinct_contour.py +++ b/tests/test_generate_succinct_contour.py @@ -44,6 +44,7 @@ class TestGenerateSuccinctContour(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, test_data, height, width, expected): result = GenerateSuccinctContour(height=height, width=width)(test_data) diff --git a/tests/test_generate_succinct_contourd.py b/tests/test_generate_succinct_contourd.py index e94a02fed5..7b023d8618 100644 --- a/tests/test_generate_succinct_contourd.py +++ b/tests/test_generate_succinct_contourd.py @@ -45,6 +45,7 @@ class TestGenerateSuccinctContour(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, data, height, width, expected): test_data = {"contour": data} diff --git a/tests/test_generate_watershed_markers.py b/tests/test_generate_watershed_markers.py index a763361913..238fb00ee0 100644 --- a/tests/test_generate_watershed_markers.py +++ b/tests/test_generate_watershed_markers.py @@ -38,6 +38,7 @@ @unittest.skipUnless(has_skimage, "Requires scikit-image library.") @unittest.skipUnless(has_scipy, "Requires scipy library.") class TestGenerateWatershedMarkers(unittest.TestCase): + @parameterized.expand(EXCEPTION_TESTS) def test_value(self, arguments, mask, probmap, exception_type): with self.assertRaises(exception_type): diff --git a/tests/test_generate_watershed_markersd.py b/tests/test_generate_watershed_markersd.py index 76d4ec1ae6..a3c2b9c231 100644 --- a/tests/test_generate_watershed_markersd.py +++ b/tests/test_generate_watershed_markersd.py @@ -68,6 +68,7 @@ @unittest.skipUnless(has_skimage, "Requires scikit-image library.") @unittest.skipUnless(has_scipy, "Requires scipy library.") class TestGenerateWatershedMarkersd(unittest.TestCase): + @parameterized.expand(EXCEPTION_TESTS) def test_value(self, arguments, mask, border_map, exception_type): with self.assertRaises(exception_type): diff --git a/tests/test_generate_watershed_mask.py b/tests/test_generate_watershed_mask.py index 1cc35dca5c..5224a912b0 100644 --- a/tests/test_generate_watershed_mask.py +++ b/tests/test_generate_watershed_mask.py @@ -58,6 +58,7 @@ @unittest.skipUnless(has_scipy, "Requires scipy library.") class TestGenerateWatershedMask(unittest.TestCase): + @parameterized.expand(EXCEPTION_TESTS) def test_value(self, arguments, exception_type): with self.assertRaises(exception_type): diff --git a/tests/test_generate_watershed_maskd.py b/tests/test_generate_watershed_maskd.py index aa6d5bf03a..9d0f2c274a 100644 --- a/tests/test_generate_watershed_maskd.py +++ b/tests/test_generate_watershed_maskd.py @@ -58,6 +58,7 @@ @unittest.skipUnless(has_scipy, "Requires scipy library.") class TestGenerateWatershedMaskd(unittest.TestCase): + @parameterized.expand(EXCEPTION_TESTS) def test_value(self, arguments, exception_type): with self.assertRaises(exception_type): diff --git a/tests/test_generator.py b/tests/test_generator.py index c336acf7ef..f531f928da 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -42,6 +42,7 @@ class TestGenerator(unittest.TestCase): + @parameterized.expand(CASES) def test_shape(self, input_param, input_data, expected_shape): net = Generator(**input_param) diff --git a/tests/test_get_equivalent_dtype.py b/tests/test_get_equivalent_dtype.py index 299a3963b7..2b4de1bc2a 100644 --- a/tests/test_get_equivalent_dtype.py +++ b/tests/test_get_equivalent_dtype.py @@ -29,6 +29,7 @@ class TestGetEquivalentDtype(unittest.TestCase): + @parameterized.expand(TESTS) def test_get_equivalent_dtype(self, im, input_dtype): out_dtype = get_equivalent_dtype(input_dtype, type(im)) diff --git a/tests/test_get_extreme_points.py b/tests/test_get_extreme_points.py index 1338ba0e2c..e60715e2fe 100644 --- a/tests/test_get_extreme_points.py +++ b/tests/test_get_extreme_points.py @@ -47,6 +47,7 @@ class TestGetExtremePoints(unittest.TestCase): + @parameterized.expand(TESTS) def test_type_shape(self, input_data, expected): result = get_extreme_points(**input_data) diff --git a/tests/test_get_layers.py b/tests/test_get_layers.py index ad0be1a5c4..5c020892ed 100644 --- a/tests/test_get_layers.py +++ b/tests/test_get_layers.py @@ -37,6 +37,7 @@ class TestGetLayers(unittest.TestCase): + @parameterized.expand(TEST_CASE_NORM) def test_norm_layer(self, input_param, expected): layer = get_norm_layer(**input_param) @@ -54,6 +55,7 @@ def test_dropout_layer(self, input_param, expected): class TestSuggestion(unittest.TestCase): + def test_suggested(self): with self.assertRaisesRegex(ValueError, "did you mean 'GROUP'?"): get_norm_layer(name="grop", spatial_dims=2) diff --git a/tests/test_get_package_version.py b/tests/test_get_package_version.py index 1881d79602..ab9e69cd31 100644 --- a/tests/test_get_package_version.py +++ b/tests/test_get_package_version.py @@ -17,6 +17,7 @@ class TestGetVersion(unittest.TestCase): + def test_default(self): output = get_package_version("42foobarnoexist") self.assertTrue("UNKNOWN" in output) diff --git a/tests/test_get_unique_labels.py b/tests/test_get_unique_labels.py index e550882243..0a88145489 100644 --- a/tests/test_get_unique_labels.py +++ b/tests/test_get_unique_labels.py @@ -35,6 +35,7 @@ class TestGetUniqueLabels(unittest.TestCase): + @parameterized.expand(TESTS) def test_correct_results(self, args, expected): result = get_unique_labels(**args) diff --git a/tests/test_gibbs_noise.py b/tests/test_gibbs_noise.py index aad5d6fea6..bdc66b9495 100644 --- a/tests/test_gibbs_noise.py +++ b/tests/test_gibbs_noise.py @@ -32,6 +32,7 @@ class TestGibbsNoise(unittest.TestCase): + def setUp(self): set_determinism(0) super().setUp() diff --git a/tests/test_gibbs_noised.py b/tests/test_gibbs_noised.py index 3aa69b7280..3b2cae7e84 100644 --- a/tests/test_gibbs_noised.py +++ b/tests/test_gibbs_noised.py @@ -33,6 +33,7 @@ class TestGibbsNoised(unittest.TestCase): + def setUp(self): set_determinism(0) super().setUp() diff --git a/tests/test_giou_loss.py b/tests/test_giou_loss.py index e794ddab30..34ee22e0ad 100644 --- a/tests/test_giou_loss.py +++ b/tests/test_giou_loss.py @@ -35,6 +35,7 @@ class TestGIoULoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_result(self, input_data, expected_val): loss = BoxGIoULoss() diff --git a/tests/test_global_mutual_information_loss.py b/tests/test_global_mutual_information_loss.py index b67ed71725..22f5e88431 100644 --- a/tests/test_global_mutual_information_loss.py +++ b/tests/test_global_mutual_information_loss.py @@ -15,6 +15,7 @@ import numpy as np import torch +from parameterized import parameterized from monai import transforms from monai.losses.image_dissimilarity import GlobalMutualInformationLoss @@ -54,6 +55,7 @@ @skip_if_quick class TestGlobalMutualInformationLoss(unittest.TestCase): + def setUp(self): config = testing_data_config("images", "Prostate_T2W_AX_1") download_url_or_skip_test( @@ -114,24 +116,34 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0. class TestGlobalMutualInformationLossIll(unittest.TestCase): - def test_ill_shape(self): - loss = GlobalMutualInformationLoss() - with self.assertRaisesRegex(ValueError, ""): - loss.forward(torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float, device=device)) - with self.assertRaisesRegex(ValueError, ""): - loss.forward(torch.ones((1, 3, 3), dtype=torch.float), torch.ones((1, 3), dtype=torch.float, device=device)) - def test_ill_opts(self): + @parameterized.expand( + [ + (torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)), # mismatched_simple_dims + ( + torch.ones((1, 3, 3), dtype=torch.float), + torch.ones((1, 3), dtype=torch.float), + ), # mismatched_advanced_dims + ] + ) + def test_ill_shape(self, input1, input2): + loss = GlobalMutualInformationLoss() + with self.assertRaises(ValueError): + loss.forward(input1, input2) + + @parameterized.expand( + [ + (0, "mean", ValueError, ""), # num_bins_zero + (-1, "mean", ValueError, ""), # num_bins_negative + (64, "unknown", ValueError, ""), # reduction_unknown + (64, None, ValueError, ""), # reduction_none + ] + ) + def test_ill_opts(self, num_bins, reduction, expected_exception, expected_message): pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device) target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device) - with self.assertRaisesRegex(ValueError, ""): - GlobalMutualInformationLoss(num_bins=0)(pred, target) - with self.assertRaisesRegex(ValueError, ""): - GlobalMutualInformationLoss(num_bins=-1)(pred, target) - with self.assertRaisesRegex(ValueError, ""): - GlobalMutualInformationLoss(reduction="unknown")(pred, target) - with self.assertRaisesRegex(ValueError, ""): - GlobalMutualInformationLoss(reduction=None)(pred, target) + with self.assertRaisesRegex(expected_exception, expected_message): + GlobalMutualInformationLoss(num_bins=num_bins, reduction=reduction)(pred, target) if __name__ == "__main__": diff --git a/tests/test_globalnet.py b/tests/test_globalnet.py index 1ab8db5926..626053377c 100644 --- a/tests/test_globalnet.py +++ b/tests/test_globalnet.py @@ -65,6 +65,7 @@ class TestAffineHead(unittest.TestCase): + @parameterized.expand(TEST_CASES_AFFINE_TRANSFORM) def test_shape(self, input_param, theta, expected_val): layer = AffineHead(**input_param) @@ -78,6 +79,7 @@ def test_shape(self, input_param, theta, expected_val): class TestGlobalNet(unittest.TestCase): + @parameterized.expand(TEST_CASES_GLOBAL_NET) def test_shape(self, input_param, input_shape, expected_shape): net = GlobalNet(**input_param).to(device) diff --git a/tests/test_gmm.py b/tests/test_gmm.py index eb638f5479..549e8f1ec4 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -261,6 +261,7 @@ @skip_if_quick class GMMTestCase(unittest.TestCase): + def setUp(self): self._var = os.environ.get("TORCH_EXTENSIONS_DIR") self.tempdir = tempfile.mkdtemp() diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index d937a5e266..4a3b4b6340 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -58,6 +58,7 @@ def identity_generator(x): class TestGridPatchDataset(unittest.TestCase): + def setUp(self): set_determinism(seed=1234) diff --git a/tests/test_grid_distortion.py b/tests/test_grid_distortion.py index 1a698140af..9ec85250e8 100644 --- a/tests/test_grid_distortion.py +++ b/tests/test_grid_distortion.py @@ -99,6 +99,7 @@ class TestGridDistortion(unittest.TestCase): + @parameterized.expand(TESTS) def test_grid_distortion(self, input_param, input_data, expected_val): g = GridDistortion(**input_param) diff --git a/tests/test_grid_distortiond.py b/tests/test_grid_distortiond.py index a645eb4f87..ce73593dc7 100644 --- a/tests/test_grid_distortiond.py +++ b/tests/test_grid_distortiond.py @@ -75,6 +75,7 @@ class TestGridDistortiond(unittest.TestCase): + @parameterized.expand(TESTS) def test_grid_distortiond(self, input_param, input_data, expected_val_img, expected_val_mask): g = GridDistortiond(**input_param) diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py index cd1c5b6531..4b324eda1a 100644 --- a/tests/test_grid_patch.py +++ b/tests/test_grid_patch.py @@ -97,6 +97,7 @@ class TestGridPatch(unittest.TestCase): + @parameterized.expand(TEST_CASES) @SkipIfBeforePyTorchVersion((1, 11, 1)) def test_grid_patch(self, in_type, input_parameters, image, expected): diff --git a/tests/test_grid_patchd.py b/tests/test_grid_patchd.py index 4f317e4677..53313b3a8f 100644 --- a/tests/test_grid_patchd.py +++ b/tests/test_grid_patchd.py @@ -77,6 +77,7 @@ class TestGridPatchd(unittest.TestCase): + @parameterized.expand(TEST_SINGLE) @SkipIfBeforePyTorchVersion((1, 11, 1)) def test_grid_patchd(self, in_type, input_parameters, image_dict, expected): diff --git a/tests/test_grid_pull.py b/tests/test_grid_pull.py index 8877b0c121..f80874d216 100644 --- a/tests/test_grid_pull.py +++ b/tests/test_grid_pull.py @@ -63,6 +63,7 @@ def make_grid(shape, dtype=None, device=None, requires_grad=True): @skip_if_no_cpp_extension class TestGridPull(unittest.TestCase): + @parameterized.expand(TEST_1D_GP, skip_on_empty=True) def test_grid_pull(self, input_param, expected): result = grid_pull(**input_param) diff --git a/tests/test_grid_split.py b/tests/test_grid_split.py index 3ccf6e75a8..852a4847a6 100644 --- a/tests/test_grid_split.py +++ b/tests/test_grid_split.py @@ -66,6 +66,7 @@ class TestGridSplit(unittest.TestCase): + @parameterized.expand(TEST_SINGLE) def test_split_patch_single_call(self, in_type, input_parameters, image, expected): input_image = in_type(image) diff --git a/tests/test_grid_splitd.py b/tests/test_grid_splitd.py index d8519b2121..215076d5a3 100644 --- a/tests/test_grid_splitd.py +++ b/tests/test_grid_splitd.py @@ -70,6 +70,7 @@ class TestGridSplitd(unittest.TestCase): + @parameterized.expand(TEST_SINGLE) def test_split_patch_single_call(self, in_type, input_parameters, img_dict, expected): input_dict = {} diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index 7dfb802bba..7b281665b4 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -23,6 +23,7 @@ class TestHandlerCheckpointLoader(unittest.TestCase): + def test_one_save_one_load(self): net1 = torch.nn.PReLU() data1 = net1.state_dict() diff --git a/tests/test_handler_checkpoint_saver.py b/tests/test_handler_checkpoint_saver.py index 70810e018f..42f99e57c9 100644 --- a/tests/test_handler_checkpoint_saver.py +++ b/tests/test_handler_checkpoint_saver.py @@ -111,6 +111,7 @@ class TestHandlerCheckpointSaver(unittest.TestCase): + @parameterized.expand( [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] ) diff --git a/tests/test_handler_classification_saver.py b/tests/test_handler_classification_saver.py index 905e326a66..5330e48dda 100644 --- a/tests/test_handler_classification_saver.py +++ b/tests/test_handler_classification_saver.py @@ -26,6 +26,7 @@ class TestHandlerClassificationSaver(unittest.TestCase): + def test_saved_content(self): with tempfile.TemporaryDirectory() as tempdir: # set up engine diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py index ef06b69683..47dca2d999 100644 --- a/tests/test_handler_classification_saver_dist.py +++ b/tests/test_handler_classification_saver_dist.py @@ -27,6 +27,7 @@ class DistributedHandlerClassificationSaver(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) def test_saved_content(self): with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_handler_clearml_image.py b/tests/test_handler_clearml_image.py index 13eebed120..91aa297b7f 100644 --- a/tests/test_handler_clearml_image.py +++ b/tests/test_handler_clearml_image.py @@ -29,6 +29,7 @@ @unittest.skipUnless(has_tb, "Requires SummaryWriter installation") @unittest.skipIf(not has_get_active_config_file, "ClearML 'get_active_config_file' not found") class TestHandlerClearMLImageHandler(unittest.TestCase): + def test_task_init(self): handle, path = tempfile.mkstemp() with open(handle, "w") as new_config: diff --git a/tests/test_handler_clearml_stats.py b/tests/test_handler_clearml_stats.py index a460bc2391..159f6af4eb 100644 --- a/tests/test_handler_clearml_stats.py +++ b/tests/test_handler_clearml_stats.py @@ -29,6 +29,7 @@ @unittest.skipUnless(has_tb, "Requires SummaryWriter installation") @unittest.skipIf(not has_get_active_config_file, "ClearML 'get_active_config_file' not found") class TestHandlerClearMLStatsHandler(unittest.TestCase): + def test_task_init(self): handle, path = tempfile.mkstemp() with open(handle, "w") as new_config: diff --git a/tests/test_handler_confusion_matrix_dist.py b/tests/test_handler_confusion_matrix_dist.py index b74b7e57c4..dd30f04142 100644 --- a/tests/test_handler_confusion_matrix_dist.py +++ b/tests/test_handler_confusion_matrix_dist.py @@ -23,6 +23,7 @@ class DistributedConfusionMatrix(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) def test_compute(self): self._compute() diff --git a/tests/test_handler_decollate_batch.py b/tests/test_handler_decollate_batch.py index 5bc5584515..37ca7f6870 100644 --- a/tests/test_handler_decollate_batch.py +++ b/tests/test_handler_decollate_batch.py @@ -22,6 +22,7 @@ class TestHandlerDecollateBatch(unittest.TestCase): + def test_compute(self): data = [ {"image": torch.tensor([[[[2.0], [3.0]]]]), "filename": ["test1"]}, diff --git a/tests/test_handler_early_stop.py b/tests/test_handler_early_stop.py index 675a804472..5fbb828330 100644 --- a/tests/test_handler_early_stop.py +++ b/tests/test_handler_early_stop.py @@ -19,7 +19,9 @@ class TestHandlerEarlyStop(unittest.TestCase): + def test_early_stop_train_loss(self): + def _train_func(engine, batch): return {"loss": 1.5} @@ -33,6 +35,7 @@ def _train_func(engine, batch): self.assertEqual(trainer.state.epoch, 2) def test_early_stop_val_metric(self): + def _train_func(engine, batch): pass diff --git a/tests/test_handler_garbage_collector.py b/tests/test_handler_garbage_collector.py index f64039b6fb..317eba1b11 100644 --- a/tests/test_handler_garbage_collector.py +++ b/tests/test_handler_garbage_collector.py @@ -34,6 +34,7 @@ class TestHandlerGarbageCollector(unittest.TestCase): + @skipUnless(has_ignite, "Requires ignite") @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) def test_content(self, data, trigger_event): diff --git a/tests/test_handler_ignite_metric.py b/tests/test_handler_ignite_metric.py index dbdc765b45..28e0b69621 100644 --- a/tests/test_handler_ignite_metric.py +++ b/tests/test_handler_ignite_metric.py @@ -99,6 +99,7 @@ class TestHandlerIgniteMetricHandler(unittest.TestCase): + @SkipIfNoModule("ignite") @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_metric_fn(self, loss_params, metric_params, handler_params, expected_avg): diff --git a/tests/test_handler_logfile.py b/tests/test_handler_logfile.py index f09876ab0a..457aca2ebc 100644 --- a/tests/test_handler_logfile.py +++ b/tests/test_handler_logfile.py @@ -30,6 +30,7 @@ class TestHandlerLogfile(unittest.TestCase): + def setUp(self): if has_ignite: # set up engine diff --git a/tests/test_handler_lr_scheduler.py b/tests/test_handler_lr_scheduler.py index f1d3f45f06..3efb4a789f 100644 --- a/tests/test_handler_lr_scheduler.py +++ b/tests/test_handler_lr_scheduler.py @@ -25,6 +25,7 @@ class TestHandlerLrSchedule(unittest.TestCase): + def test_content(self): data = [0] * 8 test_lr = 0.1 diff --git a/tests/test_handler_metric_logger.py b/tests/test_handler_metric_logger.py index 016af1e8b5..06d50e97ff 100644 --- a/tests/test_handler_metric_logger.py +++ b/tests/test_handler_metric_logger.py @@ -28,6 +28,7 @@ class TestHandlerMetricLogger(unittest.TestCase): + @SkipIfNoModule("ignite") def test_metric_logging(self): dummy_name = "dummy" diff --git a/tests/test_handler_metrics_reloaded.py b/tests/test_handler_metrics_reloaded.py index e080204d6f..b8fb39d2e8 100644 --- a/tests/test_handler_metrics_reloaded.py +++ b/tests/test_handler_metrics_reloaded.py @@ -73,6 +73,7 @@ @unittest.skipIf(not has_metrics, "MetricsReloaded not available.") class TestHandlerMetricsReloadedBinary(unittest.TestCase): + @parameterized.expand([TEST_CASE_BIN_1, TEST_CASE_BIN_2, TEST_CASE_BIN_3]) def test_compute(self, input_params, y_pred, y, expected_value): input_params["output_transform"] = from_engine(["pred", "label"]) @@ -113,6 +114,7 @@ def test_shape_mismatch(self, input_params, _y_pred, _y, _expected_value): @unittest.skipIf(not has_metrics, "MetricsReloaded not available.") class TestMetricsReloadedCategorical(unittest.TestCase): + @parameterized.expand([TEST_CASE_CAT_1, TEST_CASE_CAT_2]) def test_compute(self, input_params, y_pred, y, expected_value): input_params["output_transform"] = from_engine(["pred", "label"]) diff --git a/tests/test_handler_metrics_saver.py b/tests/test_handler_metrics_saver.py index 9888a73e5f..d5ad2f4841 100644 --- a/tests/test_handler_metrics_saver.py +++ b/tests/test_handler_metrics_saver.py @@ -24,6 +24,7 @@ class TestHandlerMetricsSaver(unittest.TestCase): + def test_content(self): with tempfile.TemporaryDirectory() as tempdir: metrics_saver = MetricsSaver( diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py index 11d7db168b..46c9ad27d7 100644 --- a/tests/test_handler_metrics_saver_dist.py +++ b/tests/test_handler_metrics_saver_dist.py @@ -27,6 +27,7 @@ class DistributedMetricsSaver(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) def test_content(self): with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py index 92cf17eadb..44adc49fc2 100644 --- a/tests/test_handler_mlflow.py +++ b/tests/test_handler_mlflow.py @@ -33,6 +33,7 @@ def get_event_filter(e): + def event_filter(_, event): if event in e: return True @@ -65,6 +66,7 @@ def _train_func(engine, batch): class TestHandlerMLFlow(unittest.TestCase): + def setUp(self): self.tmpdir_list = [] diff --git a/tests/test_handler_nvtx.py b/tests/test_handler_nvtx.py index 75cc5bc1f4..a0d1cdb4d5 100644 --- a/tests/test_handler_nvtx.py +++ b/tests/test_handler_nvtx.py @@ -36,6 +36,7 @@ class TestHandlerDecollateBatch(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX!") def test_compute(self, data, expected): diff --git a/tests/test_handler_panoptic_quality.py b/tests/test_handler_panoptic_quality.py index 1595b5ad2c..337f9c7b49 100644 --- a/tests/test_handler_panoptic_quality.py +++ b/tests/test_handler_panoptic_quality.py @@ -60,6 +60,7 @@ @SkipIfNoModule("scipy.optimize") class TestHandlerPanopticQuality(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_compute(self, input_params, expected_avg): metric = PanopticQuality(**input_params) diff --git a/tests/test_handler_parameter_scheduler.py b/tests/test_handler_parameter_scheduler.py index 1e7bbb7588..0bcc794381 100644 --- a/tests/test_handler_parameter_scheduler.py +++ b/tests/test_handler_parameter_scheduler.py @@ -21,6 +21,7 @@ class ToyNet(Module): + def __init__(self, value): super().__init__() self.value = value @@ -36,6 +37,7 @@ def set_value(self, value): class TestHandlerParameterScheduler(unittest.TestCase): + def test_linear_scheduler(self): # Testing step_constant net = ToyNet(value=-1) @@ -116,6 +118,7 @@ def test_multistep_scheduler(self): assert_allclose(net.get_value(), 10 * 0.99 * 0.99) def test_custom_scheduler(self): + def custom_logic(initial_value, gamma, current_step): return initial_value * gamma ** (current_step % 9) diff --git a/tests/test_handler_post_processing.py b/tests/test_handler_post_processing.py index c449665c1e..0dd518325b 100644 --- a/tests/test_handler_post_processing.py +++ b/tests/test_handler_post_processing.py @@ -40,6 +40,7 @@ class TestHandlerPostProcessing(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_compute(self, input_params, decollate, expected): data = [ diff --git a/tests/test_handler_prob_map_producer.py b/tests/test_handler_prob_map_producer.py index 153a00b1ac..347f8cb92c 100644 --- a/tests/test_handler_prob_map_producer.py +++ b/tests/test_handler_prob_map_producer.py @@ -30,6 +30,7 @@ class TestDataset(Dataset): + def __init__(self, name, size): super().__init__( data=[ @@ -63,11 +64,13 @@ def __getitem__(self, index): class TestEvaluator(Evaluator): + def _iteration(self, engine, batchdata): return batchdata class TestHandlerProbMapGenerator(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) def test_prob_map_generator(self, name, size): # set up dataset diff --git a/tests/test_handler_regression_metrics.py b/tests/test_handler_regression_metrics.py index a06452c54d..a3ec9f071a 100644 --- a/tests/test_handler_regression_metrics.py +++ b/tests/test_handler_regression_metrics.py @@ -46,6 +46,7 @@ def psnrmetric_np(max_val, y_pred, y): class TestHandlerRegressionMetrics(unittest.TestCase): + def test_compute(self): set_determinism(seed=123) device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/tests/test_handler_regression_metrics_dist.py b/tests/test_handler_regression_metrics_dist.py index a2e96b97d9..f57db429e8 100644 --- a/tests/test_handler_regression_metrics_dist.py +++ b/tests/test_handler_regression_metrics_dist.py @@ -57,6 +57,7 @@ def psnrmetric_np(max_val, y_pred, y): class DistributedMeanSquaredError(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) def test_compute(self): set_determinism(123) @@ -103,6 +104,7 @@ def _val_func(engine, batch): class DistributedMeanAbsoluteError(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) def test_compute(self): set_determinism(123) @@ -149,6 +151,7 @@ def _val_func(engine, batch): class DistributedRootMeanSquaredError(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) def test_compute(self): set_determinism(123) @@ -195,6 +198,7 @@ def _val_func(engine, batch): class DistributedPeakSignalToNoiseRatio(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) def test_compute(self): set_determinism(123) diff --git a/tests/test_handler_rocauc.py b/tests/test_handler_rocauc.py index ce2351a9f5..2c771340f9 100644 --- a/tests/test_handler_rocauc.py +++ b/tests/test_handler_rocauc.py @@ -21,6 +21,7 @@ class TestHandlerROCAUC(unittest.TestCase): + def test_compute(self): auc_metric = ROCAUC() act = Activations(softmax=True) diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py index 5b6ea045c7..6088251b11 100644 --- a/tests/test_handler_rocauc_dist.py +++ b/tests/test_handler_rocauc_dist.py @@ -23,6 +23,7 @@ class DistributedROCAUC(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2, node_rank=0) def test_compute(self): auc_metric = ROCAUC() diff --git a/tests/test_handler_smartcache.py b/tests/test_handler_smartcache.py index c3b4d72cb4..e544d39c72 100644 --- a/tests/test_handler_smartcache.py +++ b/tests/test_handler_smartcache.py @@ -22,6 +22,7 @@ class TestHandlerSmartCache(unittest.TestCase): + def test_content(self): data = [0, 1, 2, 3, 4, 5, 6, 7, 8] expected = [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7], [4, 5, 6, 7, 8]] diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index 1842e08635..f876cff2a3 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -26,6 +26,7 @@ def get_event_filter(e): + def event_filter(_, event): if event in e: return True @@ -35,6 +36,7 @@ def event_filter(_, event): class TestHandlerStats(unittest.TestCase): + @parameterized.expand([[True], [get_event_filter([1, 2])]]) def test_metrics_print(self, epoch_log): log_stream = StringIO() diff --git a/tests/test_handler_tb_image.py b/tests/test_handler_tb_image.py index 68b71ff7f9..197b175278 100644 --- a/tests/test_handler_tb_image.py +++ b/tests/test_handler_tb_image.py @@ -33,6 +33,7 @@ @unittest.skipUnless(has_tb, "Requires SummaryWriter installation") @SkipIfBeforePyTorchVersion((1, 13)) # issue 6683 class TestHandlerTBImage(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_tb_image_shape(self, shape): with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_handler_tb_stats.py b/tests/test_handler_tb_stats.py index 883827a1ac..b96bea13a1 100644 --- a/tests/test_handler_tb_stats.py +++ b/tests/test_handler_tb_stats.py @@ -26,6 +26,7 @@ def get_event_filter(e): + def event_filter(_, event): if event in e: return True @@ -36,6 +37,7 @@ def event_filter(_, event): @unittest.skipUnless(has_tb, "Requires SummaryWriter installation") class TestHandlerTBStats(unittest.TestCase): + def test_metrics_print(self): with tempfile.TemporaryDirectory() as tempdir: # set up engine diff --git a/tests/test_handler_validation.py b/tests/test_handler_validation.py index e1ccba2294..752b1d3df7 100644 --- a/tests/test_handler_validation.py +++ b/tests/test_handler_validation.py @@ -22,12 +22,14 @@ class TestEvaluator(Evaluator): + def _iteration(self, engine, batchdata): engine.state.output = "called" return engine.state.output class TestHandlerValidation(unittest.TestCase): + def test_content(self): data = [0] * 8 diff --git a/tests/test_hardnegsampler.py b/tests/test_hardnegsampler.py index b33cea1537..5385abd1db 100644 --- a/tests/test_hardnegsampler.py +++ b/tests/test_hardnegsampler.py @@ -37,6 +37,7 @@ class TestSampleSlices(unittest.TestCase): + @parameterized.expand(TEST_CASE) def test_shape(self, target_label0, target_label1, concat_fg_probs, expected_result_pos, expected_result_neg): compute_dtypes = [torch.float16, torch.float32] diff --git a/tests/test_hashing.py b/tests/test_hashing.py index 093de47cf9..61b3e7056b 100644 --- a/tests/test_hashing.py +++ b/tests/test_hashing.py @@ -20,6 +20,7 @@ class TestPickleHashing(unittest.TestCase): + def test_pickle(self): set_determinism(0) data1 = np.random.rand(10) @@ -45,6 +46,7 @@ def test_pickle(self): class TestJSONHashing(unittest.TestCase): + def test_json(self): data_dict1 = {"b": "str2", "a": "str1"} data_dict2 = {"a": "str1", "b": "str2"} diff --git a/tests/test_hausdorff_distance.py b/tests/test_hausdorff_distance.py index 71bbad36d2..20276a1832 100644 --- a/tests/test_hausdorff_distance.py +++ b/tests/test_hausdorff_distance.py @@ -168,6 +168,7 @@ def _describe_test_case(test_func, test_number, params): class TestHausdorffDistance(unittest.TestCase): + @parameterized.expand(TEST_CASES_EXPANDED, doc_func=_describe_test_case) def test_value(self, device, metric, directed, input_data, expected_value): percentile = None diff --git a/tests/test_hausdorff_loss.py b/tests/test_hausdorff_loss.py index 5ed20f5f3b..f2211008c2 100644 --- a/tests/test_hausdorff_loss.py +++ b/tests/test_hausdorff_loss.py @@ -198,6 +198,7 @@ def _describe_test_case(test_func, test_number, params): @skipUnless(has_scipy, "Scipy required") class TestHausdorffDTLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES, doc_func=_describe_test_case) def test_shape(self, input_param, input_data, expected_val): result = HausdorffDTLoss(**input_param).forward(**input_data) @@ -218,22 +219,18 @@ def test_ill_opts(self): with self.assertRaisesRegex(ValueError, ""): HausdorffDTLoss(reduction=None)(chn_input, chn_target) - def test_input_warnings(self): + @parameterized.expand([(False, False, False), (False, True, False), (False, False, True)]) + def test_input_warnings(self, include_background, softmax, to_onehot_y): chn_input = torch.ones((1, 1, 1, 3)) chn_target = torch.ones((1, 1, 1, 3)) with self.assertWarns(Warning): - loss = HausdorffDTLoss(include_background=False) - loss.forward(chn_input, chn_target) - with self.assertWarns(Warning): - loss = HausdorffDTLoss(softmax=True) - loss.forward(chn_input, chn_target) - with self.assertWarns(Warning): - loss = HausdorffDTLoss(to_onehot_y=True) + loss = HausdorffDTLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y) loss.forward(chn_input, chn_target) @skipUnless(has_scipy, "Scipy required") class TesLogtHausdorffDTLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES_LOG, doc_func=_describe_test_case) def test_shape(self, input_param, input_data, expected_val): result = LogHausdorffDTLoss(**input_param).forward(**input_data) @@ -254,17 +251,12 @@ def test_ill_opts(self): with self.assertRaisesRegex(ValueError, ""): LogHausdorffDTLoss(reduction=None)(chn_input, chn_target) - def test_input_warnings(self): + @parameterized.expand([(False, False, False), (False, True, False), (False, False, True)]) + def test_input_warnings(self, include_background, softmax, to_onehot_y): chn_input = torch.ones((1, 1, 1, 3)) chn_target = torch.ones((1, 1, 1, 3)) with self.assertWarns(Warning): - loss = LogHausdorffDTLoss(include_background=False) - loss.forward(chn_input, chn_target) - with self.assertWarns(Warning): - loss = LogHausdorffDTLoss(softmax=True) - loss.forward(chn_input, chn_target) - with self.assertWarns(Warning): - loss = LogHausdorffDTLoss(to_onehot_y=True) + loss = LogHausdorffDTLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y) loss.forward(chn_input, chn_target) diff --git a/tests/test_header_correct.py b/tests/test_header_correct.py index 71fed1e35d..c0ea2a8643 100644 --- a/tests/test_header_correct.py +++ b/tests/test_header_correct.py @@ -20,6 +20,7 @@ class TestCorrection(unittest.TestCase): + def test_correct(self): test_img = nib.Nifti1Image(np.zeros((1, 2, 3)), np.eye(4)) test_img.header.set_zooms((100, 100, 100)) diff --git a/tests/test_highresnet.py b/tests/test_highresnet.py index 04520419b7..bcc5739900 100644 --- a/tests/test_highresnet.py +++ b/tests/test_highresnet.py @@ -48,6 +48,7 @@ class TestHighResNet(DistTestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_shape(self, input_param, input_shape, expected_shape): net = HighResNet(**input_param).to(device) diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py index 68fa0b1192..879a74969d 100644 --- a/tests/test_hilbert_transform.py +++ b/tests/test_hilbert_transform.py @@ -161,6 +161,7 @@ def create_expected_numpy_output(input_datum, **kwargs): @SkipIfNoModule("torch.fft") class TestHilbertTransformCPU(unittest.TestCase): + @parameterized.expand( [ TEST_CASE_1D_SINE_CPU, @@ -179,6 +180,7 @@ def test_value(self, arguments, image, expected_data, atol): @skip_if_no_cuda @SkipIfNoModule("torch.fft") class TestHilbertTransformGPU(unittest.TestCase): + @parameterized.expand( ( [] @@ -201,6 +203,7 @@ def test_value(self, arguments, image, expected_data, atol): @SkipIfModule("torch.fft") class TestHilbertTransformNoFFTMod(unittest.TestCase): + def test_no_fft_module_error(self): self.assertRaises(OptionalImportError, HilbertTransform(), torch.randn(1, 1, 10)) diff --git a/tests/test_histogram_normalize.py b/tests/test_histogram_normalize.py index 3a340db52a..25c0afb64d 100644 --- a/tests/test_histogram_normalize.py +++ b/tests/test_histogram_normalize.py @@ -48,6 +48,7 @@ class TestHistogramNormalize(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, arguments, image, expected_data): result = HistogramNormalize(**arguments)(image) diff --git a/tests/test_histogram_normalized.py b/tests/test_histogram_normalized.py index 24f27d225e..a390375441 100644 --- a/tests/test_histogram_normalized.py +++ b/tests/test_histogram_normalized.py @@ -48,6 +48,7 @@ class TestHistogramNormalized(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, arguments, image, expected_data): result = HistogramNormalized(**arguments)(image)["img"] diff --git a/tests/test_hovernet.py b/tests/test_hovernet.py index d768895bdc..fb4946b011 100644 --- a/tests/test_hovernet.py +++ b/tests/test_hovernet.py @@ -154,6 +154,7 @@ def check_kernels(net, mode): class TestHoverNet(unittest.TestCase): + @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shapes): input_param["decoder_padding"] = False diff --git a/tests/test_hovernet_instance_map_post_processing.py b/tests/test_hovernet_instance_map_post_processing.py index 990e2d9a10..ce272fba1a 100644 --- a/tests/test_hovernet_instance_map_post_processing.py +++ b/tests/test_hovernet_instance_map_post_processing.py @@ -42,6 +42,7 @@ @unittest.skipUnless(has_scipy, "Requires scipy library.") @unittest.skipUnless(has_skimage, "Requires scikit-image library.") class TestHoVerNetInstanceMapPostProcessing(unittest.TestCase): + @parameterized.expand(TEST_CASE) def test_value(self, in_type, test_data, kwargs, expected_info, expected_map): nuclear_prediction = in_type(test_data.astype(float)) diff --git a/tests/test_hovernet_instance_map_post_processingd.py b/tests/test_hovernet_instance_map_post_processingd.py index 69e42d3495..c982156caa 100644 --- a/tests/test_hovernet_instance_map_post_processingd.py +++ b/tests/test_hovernet_instance_map_post_processingd.py @@ -43,6 +43,7 @@ @unittest.skipUnless(has_scipy, "Requires scipy library.") @unittest.skipUnless(has_skimage, "Requires scikit-image library.") class TestHoVerNetInstanceMapPostProcessingd(unittest.TestCase): + @parameterized.expand(TEST_CASE) def test_value(self, in_type, test_data, kwargs, expected_info, expected_map): input = { diff --git a/tests/test_hovernet_loss.py b/tests/test_hovernet_loss.py index 10db4518fa..b7cd1f3104 100644 --- a/tests/test_hovernet_loss.py +++ b/tests/test_hovernet_loss.py @@ -35,6 +35,7 @@ class PrepareTestInputs: + def __init__(self, inputs): self.inputs = {HoVerNetBranch.NP: inputs[1], HoVerNetBranch.HV: inputs[3]} self.targets = {HoVerNetBranch.NP: inputs[0], HoVerNetBranch.HV: inputs[2]} @@ -171,6 +172,7 @@ def test_shape_generator(num_classes=1, num_objects=3, batch_size=1, height=5, w class TestHoverNetLoss(unittest.TestCase): + @parameterized.expand(CASES) def test_shape(self, input_param, expected_loss): loss = HoVerNetLoss() diff --git a/tests/test_hovernet_nuclear_type_post_processing.py b/tests/test_hovernet_nuclear_type_post_processing.py index f2b33c96ae..e97b7abd2c 100644 --- a/tests/test_hovernet_nuclear_type_post_processing.py +++ b/tests/test_hovernet_nuclear_type_post_processing.py @@ -41,6 +41,7 @@ @unittest.skipUnless(has_scipy, "Requires scipy library.") @unittest.skipUnless(has_skimage, "Requires scikit-image library.") class TestHoVerNetNuclearTypePostProcessing(unittest.TestCase): + @parameterized.expand(TEST_CASE) def test_value(self, in_type, test_data, kwargs, expected_info, expected_map): nuclear_prediction = in_type(test_data.astype(float)) diff --git a/tests/test_hovernet_nuclear_type_post_processingd.py b/tests/test_hovernet_nuclear_type_post_processingd.py index 01478b7961..26cf80592c 100644 --- a/tests/test_hovernet_nuclear_type_post_processingd.py +++ b/tests/test_hovernet_nuclear_type_post_processingd.py @@ -42,6 +42,7 @@ @unittest.skipUnless(has_scipy, "Requires scipy library.") @unittest.skipUnless(has_skimage, "Requires scikit-image library.") class TestHoVerNetNuclearTypePostProcessingd(unittest.TestCase): + @parameterized.expand(TEST_CASE) def test_value(self, in_type, test_data, kwargs, expected): input = { diff --git a/tests/test_identity.py b/tests/test_identity.py index 19116cbb8f..4243a7f19a 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -18,6 +18,7 @@ class TestIdentity(NumpyImageTestCase2D): + def test_identity(self): for p in TEST_NDARRAYS: img = p(self.imt) diff --git a/tests/test_identityd.py b/tests/test_identityd.py index 98499def01..6b81ad9f16 100644 --- a/tests/test_identityd.py +++ b/tests/test_identityd.py @@ -18,6 +18,7 @@ class TestIdentityd(NumpyImageTestCase2D): + def test_identityd(self): for p in TEST_NDARRAYS: img = p(self.imt) diff --git a/tests/test_image_dataset.py b/tests/test_image_dataset.py index 7f7bdec513..fc8b4b6ccb 100644 --- a/tests/test_image_dataset.py +++ b/tests/test_image_dataset.py @@ -47,6 +47,7 @@ def __call__(self, data): class _TestCompose(Compose): + def __call__(self, data, meta, lazy): data = self.transforms[0](data) # ensure channel first data = self.transforms[1](data, lazy=lazy) # spacing @@ -57,6 +58,7 @@ def __call__(self, data, meta, lazy): class TestImageDataset(unittest.TestCase): + def test_use_case(self): with tempfile.TemporaryDirectory() as tempdir: img_ = nib.Nifti1Image(np.random.randint(0, 2, size=(20, 20, 20)).astype(float), np.eye(4)) diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index 985ea95e79..adc9dade9c 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -38,6 +38,7 @@ class TestModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -50,6 +51,7 @@ class TestNotAModuleOrTransform: class TestImageFilter(unittest.TestCase): + @parameterized.expand(SUPPORTED_FILTERS) def test_init_from_string(self, filter_name): "Test init from string" @@ -133,6 +135,7 @@ def test_pass_empty_metadata_dict(self): class TestImageFilterDict(unittest.TestCase): + @parameterized.expand(SUPPORTED_FILTERS) def test_init_from_string_dict(self, filter_name): "Test init from string and assert an error is thrown if no size is passed" @@ -162,6 +165,7 @@ def test_call_3d(self, filter_name): class TestRandImageFilter(unittest.TestCase): + @parameterized.expand(SUPPORTED_FILTERS) def test_init_from_string(self, filter_name): "Test init from string and assert an error is thrown if no size is passed" @@ -205,6 +209,7 @@ def test_call_3d_prob_0(self, filter_name): class TestRandImageFilterDict(unittest.TestCase): + @parameterized.expand(SUPPORTED_FILTERS) def test_init_from_string_dict(self, filter_name): "Test init from string and assert an error is thrown if no size is passed" diff --git a/tests/test_image_rw.py b/tests/test_image_rw.py index 79e51c53eb..7e1c1deecc 100644 --- a/tests/test_image_rw.py +++ b/tests/test_image_rw.py @@ -33,6 +33,7 @@ @unittest.skipUnless(has_itk, "itk not installed") class TestLoadSaveNifti(unittest.TestCase): + def setUp(self): self.test_dir = tempfile.mkdtemp() @@ -97,6 +98,7 @@ def test_4d(self, reader, writer): @unittest.skipUnless(has_itk, "itk not installed") class TestLoadSavePNG(unittest.TestCase): + def setUp(self): self.test_dir = tempfile.mkdtemp() @@ -137,6 +139,7 @@ def test_rgb(self, reader, writer): class TestRegRes(unittest.TestCase): + def test_0_default(self): self.assertTrue(len(resolve_writer(".png")) > 0, "has png writer") self.assertTrue(len(resolve_writer(".nrrd")) > 0, "has nrrd writer") @@ -153,6 +156,7 @@ def test_1_new(self): @unittest.skipUnless(has_itk, "itk not installed") class TestLoadSaveNrrd(unittest.TestCase): + def setUp(self): self.test_dir = tempfile.mkdtemp() diff --git a/tests/test_img2tensorboard.py b/tests/test_img2tensorboard.py index 7825f9b4d7..901ca77e7f 100644 --- a/tests/test_img2tensorboard.py +++ b/tests/test_img2tensorboard.py @@ -21,6 +21,7 @@ class TestImg2Tensorboard(unittest.TestCase): + def test_write_gray(self): nparr = np.ones(shape=(1, 32, 32, 32), dtype=np.float32) summary_object_np = make_animated_gif_summary( diff --git a/tests/test_init_reader.py b/tests/test_init_reader.py index 1350146220..cb45cb5146 100644 --- a/tests/test_init_reader.py +++ b/tests/test_init_reader.py @@ -19,6 +19,7 @@ class TestInitLoadImage(unittest.TestCase): + def test_load_image(self): instance1 = LoadImage(image_only=False, dtype=None) instance2 = LoadImage(image_only=True, dtype=None) diff --git a/tests/test_integration_autorunner.py b/tests/test_integration_autorunner.py index 7110db568d..31a0813abc 100644 --- a/tests/test_integration_autorunner.py +++ b/tests/test_integration_autorunner.py @@ -71,6 +71,7 @@ @SkipIfBeforePyTorchVersion((1, 11, 1)) # for mem_get_info @unittest.skipIf(not has_tb, "no tensorboard summary writer") class TestAutoRunner(unittest.TestCase): + def setUp(self) -> None: self.test_dir = tempfile.TemporaryDirectory() test_path = self.test_dir.name diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py index bd96f50c55..c2e0fb55b7 100644 --- a/tests/test_integration_bundle_run.py +++ b/tests/test_integration_bundle_run.py @@ -37,6 +37,7 @@ class _Runnable42: + def __init__(self, val=1): self.val = val @@ -46,6 +47,7 @@ def run(self): class _Runnable43: + def __init__(self, func): self.func = func @@ -54,6 +56,7 @@ def run(self): class TestBundleRun(unittest.TestCase): + def setUp(self): self.data_dir = tempfile.mkdtemp() diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index 4fc92c4068..b137fc9b75 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -45,6 +45,7 @@ class MedNISTDataset(torch.utils.data.Dataset): + def __init__(self, image_files, labels, transforms): self.image_files = image_files self.labels = labels @@ -182,6 +183,7 @@ def run_inference_test(root_dir, test_x, test_y, device="cuda:0", num_workers=10 @skip_if_quick class IntegrationClassification2D(DistTestCase): + def setUp(self): set_determinism(seed=0) self.data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") diff --git a/tests/test_integration_determinism.py b/tests/test_integration_determinism.py index 6821279080..3e88f05620 100644 --- a/tests/test_integration_determinism.py +++ b/tests/test_integration_determinism.py @@ -26,7 +26,9 @@ def run_test(batch_size=64, train_steps=200, device="cuda:0"): + class _TestBatch(Dataset): + def __init__(self, transforms): self.transforms = transforms @@ -76,6 +78,7 @@ def __len__(self): class TestDeterminism(DistTestCase): + def setUp(self): self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0") diff --git a/tests/test_integration_fast_train.py b/tests/test_integration_fast_train.py index 497fe22dab..071eb5cf78 100644 --- a/tests/test_integration_fast_train.py +++ b/tests/test_integration_fast_train.py @@ -58,6 +58,7 @@ @skip_if_no_cuda @skip_if_quick class IntegrationFastTrain(DistTestCase): + def setUp(self): set_determinism(seed=0) monai.config.print_config() diff --git a/tests/test_integration_gpu_customization.py b/tests/test_integration_gpu_customization.py index 44165b967c..043405a580 100644 --- a/tests/test_integration_gpu_customization.py +++ b/tests/test_integration_gpu_customization.py @@ -70,6 +70,7 @@ @SkipIfBeforePyTorchVersion((1, 11, 1)) # module 'torch.cuda' has no attribute 'mem_get_info' @unittest.skipIf(not has_tb, "no tensorboard summary writer") class TestEnsembleGpuCustomization(unittest.TestCase): + def setUp(self) -> None: self.test_dir = tempfile.TemporaryDirectory() diff --git a/tests/test_integration_lazy_samples.py b/tests/test_integration_lazy_samples.py index c365616bc8..51d80e7305 100644 --- a/tests/test_integration_lazy_samples.py +++ b/tests/test_integration_lazy_samples.py @@ -160,6 +160,7 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, @skip_if_quick @SkipIfBeforePyTorchVersion((1, 11)) class IntegrationLazyResampling(DistTestCase): + def setUp(self): monai.config.print_config() set_determinism(seed=0) diff --git a/tests/test_integration_nnunetv2_runner.py b/tests/test_integration_nnunetv2_runner.py index d35737f86f..822d454f52 100644 --- a/tests/test_integration_nnunetv2_runner.py +++ b/tests/test_integration_nnunetv2_runner.py @@ -49,6 +49,7 @@ @unittest.skipIf(not has_tb, "no tensorboard summary writer") @unittest.skipIf(not has_nnunet, "no nnunetv2") class TestnnUNetV2Runner(unittest.TestCase): + def setUp(self) -> None: self.test_dir = tempfile.TemporaryDirectory() test_path = self.test_dir.name diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index 2e4cc31645..c72369b151 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -235,6 +235,7 @@ def run_inference_test(root_dir, device="cuda:0"): @skip_if_quick class IntegrationSegmentation3D(DistTestCase): + def setUp(self): set_determinism(seed=0) diff --git a/tests/test_integration_sliding_window.py b/tests/test_integration_sliding_window.py index bcc66a687e..8b53e94941 100644 --- a/tests/test_integration_sliding_window.py +++ b/tests/test_integration_sliding_window.py @@ -72,6 +72,7 @@ def save_func(engine): @skip_if_quick class TestIntegrationSlidingWindow(DistTestCase): + def setUp(self): set_determinism(seed=0) diff --git a/tests/test_integration_stn.py b/tests/test_integration_stn.py index c858060c31..750a20ea5c 100644 --- a/tests/test_integration_stn.py +++ b/tests/test_integration_stn.py @@ -98,6 +98,7 @@ def compare_2d(is_ref=True, device=None, reverse_indexing=False): class TestSpatialTransformerCore(DistTestCase): + def setUp(self): self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0") diff --git a/tests/test_integration_unet_2d.py b/tests/test_integration_unet_2d.py index 90c0098d36..918190775c 100644 --- a/tests/test_integration_unet_2d.py +++ b/tests/test_integration_unet_2d.py @@ -25,7 +25,9 @@ def run_test(net_name="basicunet", batch_size=64, train_steps=100, device="cuda:0"): + class _TestBatch(Dataset): + def __getitem__(self, _unused_id): im, seg = create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1) return im[None], seg[None].astype(np.float32) @@ -54,6 +56,7 @@ def __len__(self): @skip_if_quick class TestIntegrationUnet2D(DistTestCase): + @TimedCall(seconds=20, daemon=False) def test_unet_training(self): for n in ["basicunet", "unet"]: diff --git a/tests/test_integration_workers.py b/tests/test_integration_workers.py index 33c26cedf8..123b1ddc6f 100644 --- a/tests/test_integration_workers.py +++ b/tests/test_integration_workers.py @@ -44,6 +44,7 @@ def run_loading_test(num_workers=50, device=None, pw=False): @skip_if_no_cuda @SkipIfBeforePyTorchVersion((1, 9)) class IntegrationLoading(DistTestCase): + def tearDown(self): set_determinism(seed=None) diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 7c6f35f3d3..fafb66f675 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -118,6 +118,7 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): ) class _TestEvalIterEvents: + def attach(self, engine): engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed) @@ -160,6 +161,7 @@ def _forward_completed(self, engine): ) class _TestTrainIterEvents: + def attach(self, engine): engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed) engine.add_event_handler(IterationEvents.LOSS_COMPLETED, self._loss_completed) @@ -284,6 +286,7 @@ def save_func(engine): @skip_if_quick class IntegrationWorkflows(DistTestCase): + def setUp(self): set_determinism(seed=0) diff --git a/tests/test_integration_workflows_gan.py b/tests/test_integration_workflows_gan.py index 6896241d35..1428506020 100644 --- a/tests/test_integration_workflows_gan.py +++ b/tests/test_integration_workflows_gan.py @@ -127,6 +127,7 @@ def generator_loss(gen_images): @skip_if_quick class IntegrationWorkflowsGAN(DistTestCase): + def setUp(self): set_determinism(seed=0) diff --git a/tests/test_intensity_stats.py b/tests/test_intensity_stats.py index 243fcd0dd4..e45c2acbad 100644 --- a/tests/test_intensity_stats.py +++ b/tests/test_intensity_stats.py @@ -53,6 +53,7 @@ class TestIntensityStats(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, input_param, img, meta_dict, expected): _, meta_dict = IntensityStats(**input_param)(img, meta_dict) diff --git a/tests/test_intensity_statsd.py b/tests/test_intensity_statsd.py index 3fe82b1df7..d164f249db 100644 --- a/tests/test_intensity_statsd.py +++ b/tests/test_intensity_statsd.py @@ -52,6 +52,7 @@ class TestIntensityStatsd(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_value(self, input_param, data, meta_key, expected): meta = IntensityStatsd(**input_param)(data)[meta_key] diff --git a/tests/test_inverse_array.py b/tests/test_inverse_array.py index c0b1a77e55..4da9ee34b9 100644 --- a/tests/test_inverse_array.py +++ b/tests/test_inverse_array.py @@ -33,6 +33,7 @@ @unittest.skipUnless(has_nib, "Requires nibabel") class TestInverseArray(unittest.TestCase): + @staticmethod def get_image(dtype, device) -> MetaTensor: affine = torch.tensor([[0, 0, 1, 0], [-1, 0, 0, 0], [0, 10, 0, 0], [0, 0, 0, 1]]).to(dtype).to(device) diff --git a/tests/test_invert.py b/tests/test_invert.py index 9c57b11331..69d31edfc8 100644 --- a/tests/test_invert.py +++ b/tests/test_invert.py @@ -41,6 +41,7 @@ class TestInvert(unittest.TestCase): + def test_invert(self): set_determinism(seed=0) im_fname = make_nifti_image(create_test_image_3d(101, 100, 107, noise_max=100)[1]) # label image, discrete diff --git a/tests/test_invertd.py b/tests/test_invertd.py index 2e6ee35981..c32a3af643 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -43,6 +43,7 @@ class TestInvertd(unittest.TestCase): + def test_invert(self): set_determinism(seed=0) im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)) diff --git a/tests/test_is_supported_format.py b/tests/test_is_supported_format.py index 591772bb3a..fb488eb054 100644 --- a/tests/test_is_supported_format.py +++ b/tests/test_is_supported_format.py @@ -33,6 +33,7 @@ class TestIsSupportedFormat(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) def test_value(self, input_param, result): self.assertEqual(is_supported_format(**input_param), result) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 38be9ec30c..cfa711e4c0 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -24,6 +24,7 @@ class _Stream: + def __init__(self, data): self.data = data @@ -32,6 +33,7 @@ def __iter__(self): class TestIterableDataset(unittest.TestCase): + def test_shape(self): expected_shape = (128, 128, 128) test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) diff --git a/tests/test_itk_torch_bridge.py b/tests/test_itk_torch_bridge.py index b368230c53..22ae019271 100644 --- a/tests/test_itk_torch_bridge.py +++ b/tests/test_itk_torch_bridge.py @@ -49,6 +49,7 @@ @unittest.skipUnless(has_itk, "Requires `itk` package.") class TestITKTorchAffineMatrixBridge(unittest.TestCase): + def setUp(self): set_determinism(seed=0) self.data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") @@ -493,6 +494,7 @@ def test_use_reference_space(self, ref_filepath, filepath): @unittest.skipUnless(has_nib, "Requires `nibabel` package.") @skip_if_quick class TestITKTorchRW(unittest.TestCase): + def setUp(self): TestITKTorchAffineMatrixBridge.setUp(self) diff --git a/tests/test_itk_writer.py b/tests/test_itk_writer.py index c9707b1b5a..6625339dd0 100644 --- a/tests/test_itk_writer.py +++ b/tests/test_itk_writer.py @@ -27,6 +27,7 @@ @unittest.skipUnless(has_itk, "Requires `itk` package.") class TestITKWriter(unittest.TestCase): + def test_channel_shape(self): with tempfile.TemporaryDirectory() as tempdir: for c in (0, 1, 2, 3): diff --git a/tests/test_k_space_spike_noise.py b/tests/test_k_space_spike_noise.py index 4d820573a6..17acedf319 100644 --- a/tests/test_k_space_spike_noise.py +++ b/tests/test_k_space_spike_noise.py @@ -32,6 +32,7 @@ class TestKSpaceSpikeNoise(unittest.TestCase): + def setUp(self): set_determinism(0) super().setUp() diff --git a/tests/test_k_space_spike_noised.py b/tests/test_k_space_spike_noised.py index 76a79d4b12..ce542af0aa 100644 --- a/tests/test_k_space_spike_noised.py +++ b/tests/test_k_space_spike_noised.py @@ -33,6 +33,7 @@ class TestKSpaceSpikeNoised(unittest.TestCase): + def setUp(self): set_determinism(0) super().setUp() diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py index 7da3c4b21f..2dfac1142e 100644 --- a/tests/test_keep_largest_connected_component.py +++ b/tests/test_keep_largest_connected_component.py @@ -381,6 +381,7 @@ def to_onehot(x): class TestKeepLargestConnectedComponent(unittest.TestCase): + @parameterized.expand(TESTS) def test_correct_results(self, _, args, input_image, expected): converter = KeepLargestConnectedComponent(**args) diff --git a/tests/test_keep_largest_connected_componentd.py b/tests/test_keep_largest_connected_componentd.py index aac91a2de9..4d3172741d 100644 --- a/tests/test_keep_largest_connected_componentd.py +++ b/tests/test_keep_largest_connected_componentd.py @@ -337,6 +337,7 @@ class TestKeepLargestConnectedComponentd(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_results(self, _, args, input_dict, expected): converter = KeepLargestConnectedComponentd(**args) diff --git a/tests/test_kspace_mask.py b/tests/test_kspace_mask.py index 5d6d9c18ea..cfbd7864c8 100644 --- a/tests/test_kspace_mask.py +++ b/tests/test_kspace_mask.py @@ -26,6 +26,7 @@ class TestMRIUtils(unittest.TestCase): + @parameterized.expand(TESTSM) def test_mask(self, test_data): # random mask diff --git a/tests/test_label_filter.py b/tests/test_label_filter.py index 47a8706491..93cf95a2a0 100644 --- a/tests/test_label_filter.py +++ b/tests/test_label_filter.py @@ -58,6 +58,7 @@ class TestLabelFilter(unittest.TestCase): + @parameterized.expand(VALID_TESTS) def test_correct_results(self, _, args, input_image, expected): converter = LabelFilter(**args) diff --git a/tests/test_label_filterd.py b/tests/test_label_filterd.py index f27df08c2a..fba8100f25 100644 --- a/tests/test_label_filterd.py +++ b/tests/test_label_filterd.py @@ -58,6 +58,7 @@ class TestLabelFilter(unittest.TestCase): + @parameterized.expand(VALID_TESTS) def test_correct_results(self, _, args, input_image, expected): converter = LabelFilterd(keys="image", **args) diff --git a/tests/test_label_quality_score.py b/tests/test_label_quality_score.py index aa243b4236..a46b78b1d4 100644 --- a/tests/test_label_quality_score.py +++ b/tests/test_label_quality_score.py @@ -99,6 +99,7 @@ class TestLabelQualityScore(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_value(self, input_data, expected_value): result = label_quality_score(**input_data) diff --git a/tests/test_label_to_contour.py b/tests/test_label_to_contour.py index 590fd5d4e4..d7fbfc9b8d 100644 --- a/tests/test_label_to_contour.py +++ b/tests/test_label_to_contour.py @@ -142,6 +142,7 @@ def gen_fixed_img(array_type): class TestContour(unittest.TestCase): + def test_contour(self): input_param = {"kernel_type": "Laplace"} diff --git a/tests/test_label_to_contourd.py b/tests/test_label_to_contourd.py index 6fcec72dd8..a91a712da6 100644 --- a/tests/test_label_to_contourd.py +++ b/tests/test_label_to_contourd.py @@ -143,6 +143,7 @@ def gen_fixed_img(array_type): class TestContourd(unittest.TestCase): + def test_contour(self): input_param = {"keys": "img", "kernel_type": "Laplace"} diff --git a/tests/test_label_to_mask.py b/tests/test_label_to_mask.py index 2eba825cf3..47a58cc989 100644 --- a/tests/test_label_to_mask.py +++ b/tests/test_label_to_mask.py @@ -59,6 +59,7 @@ class TestLabelToMask(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, arguments, image, expected_data): result = LabelToMask(**arguments)(image) diff --git a/tests/test_label_to_maskd.py b/tests/test_label_to_maskd.py index 35f54ca5b9..44b537128d 100644 --- a/tests/test_label_to_maskd.py +++ b/tests/test_label_to_maskd.py @@ -59,6 +59,7 @@ class TestLabelToMaskd(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, arguments, input_data, expected_data): result = LabelToMaskd(**arguments)(input_data) diff --git a/tests/test_lambda.py b/tests/test_lambda.py index e2276d671c..e0a5cf84db 100644 --- a/tests/test_lambda.py +++ b/tests/test_lambda.py @@ -23,6 +23,7 @@ class TestLambda(NumpyImageTestCase2D): + def test_lambda_identity(self): for p in TEST_NDARRAYS: img = p(self.imt) diff --git a/tests/test_lambdad.py b/tests/test_lambdad.py index 02e4423b74..fad5ebeee4 100644 --- a/tests/test_lambdad.py +++ b/tests/test_lambdad.py @@ -23,6 +23,7 @@ class TestLambdad(NumpyImageTestCase2D): + def test_lambdad_identity(self): for p in TEST_NDARRAYS: img = p(self.imt) diff --git a/tests/test_lesion_froc.py b/tests/test_lesion_froc.py index 10682c2bb7..0622809102 100644 --- a/tests/test_lesion_froc.py +++ b/tests/test_lesion_froc.py @@ -298,6 +298,7 @@ def prepare_test_data(): class TestEvaluateTumorFROC(unittest.TestCase): + @skipUnless(has_cucim, "Requires cucim") @skipUnless(has_skimage, "Requires skimage") @skipUnless(has_sp, "Requires scipy") diff --git a/tests/test_list_data_collate.py b/tests/test_list_data_collate.py index 9be61e3999..56ee040758 100644 --- a/tests/test_list_data_collate.py +++ b/tests/test_list_data_collate.py @@ -37,6 +37,7 @@ class TestListDataCollate(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_type_shape(self, input_data, expected_type, expected_shape): result = list_data_collate(input_data) diff --git a/tests/test_list_to_dict.py b/tests/test_list_to_dict.py index 4e6bb8cdf7..abb61ea182 100644 --- a/tests/test_list_to_dict.py +++ b/tests/test_list_to_dict.py @@ -32,6 +32,7 @@ class TestListToDict(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_value_shape(self, input, output): result = list_to_dict(input) diff --git a/tests/test_lltm.py b/tests/test_lltm.py index 6ee716e1ef..cc64672e77 100644 --- a/tests/test_lltm.py +++ b/tests/test_lltm.py @@ -29,6 +29,7 @@ class TestLLTM(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) @SkipIfNoModule("monai._C") def test_value(self, input_param, expected_h, expected_c): diff --git a/tests/test_lmdbdataset.py b/tests/test_lmdbdataset.py index 155b4eb0fc..9d128dd728 100644 --- a/tests/test_lmdbdataset.py +++ b/tests/test_lmdbdataset.py @@ -81,6 +81,7 @@ class _InplaceXform(Transform): + def __call__(self, data): if data: data[0] = data[0] + np.pi @@ -91,6 +92,7 @@ def __call__(self, data): @skip_if_windows class TestLMDBDataset(unittest.TestCase): + def test_cache(self): """testing no inplace change to the hashed item""" items = [[list(range(i))] for i in range(5)] diff --git a/tests/test_lmdbdataset_dist.py b/tests/test_lmdbdataset_dist.py index 0b4c7c35fa..1acb89beb3 100644 --- a/tests/test_lmdbdataset_dist.py +++ b/tests/test_lmdbdataset_dist.py @@ -23,6 +23,7 @@ class _InplaceXform(Transform): + def __call__(self, data): if data: data[0] = data[0] + np.pi @@ -33,6 +34,7 @@ def __call__(self, data): @skip_if_windows class TestMPLMDBDataset(DistTestCase): + def setUp(self): self.tempdir = tempfile.mkdtemp() diff --git a/tests/test_load_decathlon_datalist.py b/tests/test_load_decathlon_datalist.py index b0e390cd73..7281034498 100644 --- a/tests/test_load_decathlon_datalist.py +++ b/tests/test_load_decathlon_datalist.py @@ -21,6 +21,7 @@ class TestLoadDecathlonDatalist(unittest.TestCase): + def test_seg_values(self): with tempfile.TemporaryDirectory() as tempdir: test_data = { diff --git a/tests/test_load_image.py b/tests/test_load_image.py index b6a10bceb4..0207079d7d 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -160,6 +160,7 @@ def get_data(self, _obj): @unittest.skipUnless(has_itk, "itk not installed") class TestLoadImage(unittest.TestCase): + @classmethod def setUpClass(cls): super(__class__, cls).setUpClass() @@ -379,6 +380,7 @@ def test_channel_dim(self, input_param, filename, expected_shape): @unittest.skipUnless(has_itk, "itk not installed") class TestLoadImageMeta(unittest.TestCase): + @classmethod def setUpClass(cls): super(__class__, cls).setUpClass() diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py index 534cbb6618..699ed70059 100644 --- a/tests/test_load_imaged.py +++ b/tests/test_load_imaged.py @@ -46,6 +46,7 @@ @unittest.skipUnless(has_itk, "itk not installed") class TestLoadImaged(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, input_param, expected_shape): test_image = nib.Nifti1Image(np.random.rand(128, 128, 128), np.eye(4)) @@ -94,6 +95,7 @@ def test_no_file(self): @unittest.skipUnless(has_itk, "itk not installed") class TestConsistency(unittest.TestCase): + def _cmp(self, filename, ch_shape, reader_1, reader_2, outname, ext): data_dict = {"img": filename} keys = data_dict.keys() @@ -155,6 +157,7 @@ def test_png(self): @unittest.skipUnless(has_itk, "itk not installed") class TestLoadImagedMeta(unittest.TestCase): + @classmethod def setUpClass(cls): super(__class__, cls).setUpClass() diff --git a/tests/test_load_spacing_orientation.py b/tests/test_load_spacing_orientation.py index a121bdd3cd..63422761ca 100644 --- a/tests/test_load_spacing_orientation.py +++ b/tests/test_load_spacing_orientation.py @@ -30,6 +30,7 @@ class TestLoadSpacingOrientation(unittest.TestCase): + @staticmethod def load_image(filename): data = {"image": filename} diff --git a/tests/test_loader_semaphore.py b/tests/test_loader_semaphore.py index 859ee1f8d5..83557d830d 100644 --- a/tests/test_loader_semaphore.py +++ b/tests/test_loader_semaphore.py @@ -39,6 +39,7 @@ def _run_test(): class TestImportLock(unittest.TestCase): + def test_start(self): _run_test() diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index 21fe8b973f..35a24cd0ca 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -117,6 +117,7 @@ class TestLocalNormalizedCrossCorrelationLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): result = LocalNormalizedCrossCorrelationLoss(**input_param).forward(**input_data) diff --git a/tests/test_localnet.py b/tests/test_localnet.py index f557147960..97aa94d2c5 100644 --- a/tests/test_localnet.py +++ b/tests/test_localnet.py @@ -62,6 +62,7 @@ class TestLocalNet(unittest.TestCase): + @parameterized.expand(TEST_CASE_LOCALNET_2D + TEST_CASE_LOCALNET_3D) def test_shape(self, input_param, input_shape, expected_shape): net = LocalNet(**input_param).to(device) diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py index 27ea4cd1a6..340a8e94ba 100644 --- a/tests/test_localnet_block.py +++ b/tests/test_localnet_block.py @@ -48,6 +48,7 @@ class TestLocalNetDownSampleBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_DOWN_SAMPLE) def test_shape(self, input_param): net = LocalNetDownSampleBlock(**input_param) @@ -74,6 +75,7 @@ def test_ill_shape(self, input_param): class TestLocalNetUpSampleBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_UP_SAMPLE) def test_shape(self, input_param): net = LocalNetUpSampleBlock(**input_param) @@ -100,6 +102,7 @@ def test_ill_shape(self, input_param): class TestExtractBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_EXTRACT) def test_shape(self, input_param): net = LocalNetFeatureExtractorBlock(**input_param) diff --git a/tests/test_look_up_option.py b/tests/test_look_up_option.py index 5f81fb8d43..d40b7eaa8c 100644 --- a/tests/test_look_up_option.py +++ b/tests/test_look_up_option.py @@ -44,6 +44,7 @@ class _CaseStrEnum(StrEnum): class TestLookUpOption(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_look_up(self, input_str, supported, expected): output = look_up_option(input_str, supported) diff --git a/tests/test_loss_metric.py b/tests/test_loss_metric.py index 682221f5f5..365dc10670 100644 --- a/tests/test_loss_metric.py +++ b/tests/test_loss_metric.py @@ -36,6 +36,7 @@ class TestComputeLossMetric(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) def test_value_class(self, input_data, expected_value): loss_fn = input_data["loss_class"](**input_data["loss_kwargs"]) diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py index 46375890eb..d26cb23a90 100644 --- a/tests/test_lr_finder.py +++ b/tests/test_lr_finder.py @@ -48,6 +48,7 @@ @unittest.skipUnless(sys.platform == "linux", "requires linux") @unittest.skipUnless(has_pil, "requires PIL") class TestLRFinder(unittest.TestCase): + def setUp(self): self.root_dir = MONAIEnvVars.data_dir() if not self.root_dir: diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py index 54092ba931..1a61796fe0 100644 --- a/tests/test_lr_scheduler.py +++ b/tests/test_lr_scheduler.py @@ -20,6 +20,7 @@ class SchedulerTestNet(torch.nn.Module): + def __init__(self): super().__init__() self.conv1 = torch.nn.Conv2d(1, 1, 1) @@ -43,6 +44,7 @@ def forward(self, x): class TestLRSCHEDULER(unittest.TestCase): + @parameterized.expand(TEST_CASE_LRSCHEDULER) def test_shape(self, input_param, expected_lr): net = SchedulerTestNet() diff --git a/tests/test_make_nifti.py b/tests/test_make_nifti.py index 4560507c6c..08d3a731ab 100644 --- a/tests/test_make_nifti.py +++ b/tests/test_make_nifti.py @@ -34,6 +34,7 @@ @unittest.skipUnless(has_nib, "Requires nibabel") class TestMakeNifti(unittest.TestCase): + @parameterized.expand(TESTS) def test_make_nifti(self, params): im, _ = create_test_image_2d(100, 88) diff --git a/tests/test_map_binary_to_indices.py b/tests/test_map_binary_to_indices.py index 1080c2a513..9931d997bb 100644 --- a/tests/test_map_binary_to_indices.py +++ b/tests/test_map_binary_to_indices.py @@ -64,6 +64,7 @@ class TestMapBinaryToIndices(unittest.TestCase): + @parameterized.expand(TESTS) def test_type_shape(self, input_data, expected_fg, expected_bg): fg_indices, bg_indices = map_binary_to_indices(**input_data) diff --git a/tests/test_map_classes_to_indices.py b/tests/test_map_classes_to_indices.py index 9c8b4b4793..902744ab65 100644 --- a/tests/test_map_classes_to_indices.py +++ b/tests/test_map_classes_to_indices.py @@ -124,6 +124,7 @@ class TestMapClassesToIndices(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, input_data, expected_indices): indices = map_classes_to_indices(**input_data) diff --git a/tests/test_map_label_value.py b/tests/test_map_label_value.py index 6b8121b6df..cd311df6bd 100644 --- a/tests/test_map_label_value.py +++ b/tests/test_map_label_value.py @@ -75,6 +75,7 @@ class TestMapLabelValue(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_value): result = MapLabelValue(**input_param)(input_data) diff --git a/tests/test_map_label_valued.py b/tests/test_map_label_valued.py index fa0d094393..0fb46f2515 100644 --- a/tests/test_map_label_valued.py +++ b/tests/test_map_label_valued.py @@ -69,6 +69,7 @@ class TestMapLabelValued(unittest.TestCase): + @parameterized.expand( [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_5_1, TEST_CASE_6, TEST_CASE_7] ) diff --git a/tests/test_map_transform.py b/tests/test_map_transform.py index 7430cf09c7..a7be7b9f5d 100644 --- a/tests/test_map_transform.py +++ b/tests/test_map_transform.py @@ -23,11 +23,13 @@ class MapTest(MapTransform): + def __call__(self, data): pass class TestRandomizable(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_keys(self, keys, expected): transform = MapTest(keys=keys) diff --git a/tests/test_mask_intensity.py b/tests/test_mask_intensity.py index 2b831ba415..b7ff324946 100644 --- a/tests/test_mask_intensity.py +++ b/tests/test_mask_intensity.py @@ -55,6 +55,7 @@ class TestMaskIntensity(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_value(self, arguments, image, expected_data): for p in TEST_NDARRAYS: diff --git a/tests/test_mask_intensityd.py b/tests/test_mask_intensityd.py index 6a39416de4..0efd1f835f 100644 --- a/tests/test_mask_intensityd.py +++ b/tests/test_mask_intensityd.py @@ -57,6 +57,7 @@ class TestMaskIntensityd(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_value(self, arguments, image, expected_data): result = MaskIntensityd(**arguments)(image) diff --git a/tests/test_masked_dice_loss.py b/tests/test_masked_dice_loss.py index b868f4d3a1..c971723615 100644 --- a/tests/test_masked_dice_loss.py +++ b/tests/test_masked_dice_loss.py @@ -113,6 +113,7 @@ class TestDiceLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): result = MaskedDiceLoss(**input_param).forward(**input_data) diff --git a/tests/test_masked_loss.py b/tests/test_masked_loss.py index 708d507523..3c04ffadcb 100644 --- a/tests/test_masked_loss.py +++ b/tests/test_masked_loss.py @@ -40,6 +40,7 @@ class TestMaskedLoss(unittest.TestCase): + def setUp(self): set_determinism(0) diff --git a/tests/test_masked_patch_wsi_dataset.py b/tests/test_masked_patch_wsi_dataset.py index 35509b32f6..8d24075595 100644 --- a/tests/test_masked_patch_wsi_dataset.py +++ b/tests/test_masked_patch_wsi_dataset.py @@ -74,6 +74,7 @@ def setUpModule(): class MaskedPatchWSIDatasetTests: + class Tests(unittest.TestCase): backend = None @@ -100,6 +101,7 @@ def test_gen_patches(self, input_parameters, expected): @skipUnless(has_cucim, "Requires cucim") class TestSlidingPatchWSIDatasetCuCIM(MaskedPatchWSIDatasetTests.Tests): + @classmethod def setUpClass(cls): cls.backend = "cucim" @@ -107,6 +109,7 @@ def setUpClass(cls): @skipUnless(has_osl, "Requires openslide") class TestSlidingPatchWSIDatasetOpenSlide(MaskedPatchWSIDatasetTests.Tests): + @classmethod def setUpClass(cls): cls.backend = "openslide" diff --git a/tests/test_matshow3d.py b/tests/test_matshow3d.py index a6cb3fcee3..e513025e69 100644 --- a/tests/test_matshow3d.py +++ b/tests/test_matshow3d.py @@ -35,6 +35,7 @@ @SkipIfNoModule("matplotlib") class TestMatshow3d(unittest.TestCase): + def test_3d(self): testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") keys = "image" diff --git a/tests/test_mean_ensemble.py b/tests/test_mean_ensemble.py index 09b7f94dc4..6b463f8530 100644 --- a/tests/test_mean_ensemble.py +++ b/tests/test_mean_ensemble.py @@ -58,6 +58,7 @@ class TestMeanEnsemble(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, input_param, img, expected_value): result = MeanEnsemble(**input_param)(img) diff --git a/tests/test_mean_ensembled.py b/tests/test_mean_ensembled.py index 01123b0729..795ae47368 100644 --- a/tests/test_mean_ensembled.py +++ b/tests/test_mean_ensembled.py @@ -72,6 +72,7 @@ class TestMeanEnsembled(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, input_param, data, expected_value): result = MeanEnsembled(**input_param)(data) diff --git a/tests/test_median_filter.py b/tests/test_median_filter.py index 9f27adff4c..516388afce 100644 --- a/tests/test_median_filter.py +++ b/tests/test_median_filter.py @@ -15,26 +15,20 @@ import numpy as np import torch +from parameterized import parameterized from monai.networks.layers import MedianFilter class MedianFilterTestCase(unittest.TestCase): - def test_3d_big(self): - a = torch.ones(1, 1, 2, 3, 5) - g = MedianFilter([1, 2, 4]).to(torch.device("cpu:0")) + @parameterized.expand([(torch.ones(1, 1, 2, 3, 5), [1, 2, 4]), (torch.ones(1, 1, 4, 3, 4), 1)]) # 3d_big # 3d + def test_3d(self, input_tensor, radius): + filter = MedianFilter(radius).to(torch.device("cpu:0")) - expected = a.numpy() - out = g(a).cpu().numpy() - np.testing.assert_allclose(out, expected, rtol=1e-5) - - def test_3d(self): - a = torch.ones(1, 1, 4, 3, 4) - g = MedianFilter(1).to(torch.device("cpu:0")) + expected = input_tensor.numpy() + output = filter(input_tensor).cpu().numpy() - expected = a.numpy() - out = g(a).cpu().numpy() - np.testing.assert_allclose(out, expected, rtol=1e-5) + np.testing.assert_allclose(output, expected, rtol=1e-5) def test_3d_radii(self): a = torch.ones(1, 1, 4, 3, 2) diff --git a/tests/test_median_smooth.py b/tests/test_median_smooth.py index 21cd45f28e..5930c0c6b6 100644 --- a/tests/test_median_smooth.py +++ b/tests/test_median_smooth.py @@ -31,6 +31,7 @@ class TestMedianSmooth(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, arguments, image, expected_data): result = MedianSmooth(**arguments)(image) diff --git a/tests/test_median_smoothd.py b/tests/test_median_smoothd.py index b8d3452c86..e0bdb331c8 100644 --- a/tests/test_median_smoothd.py +++ b/tests/test_median_smoothd.py @@ -55,6 +55,7 @@ class TestMedianSmoothd(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, arguments, image, expected_data): result = MedianSmoothd(**arguments)(image) diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index b5a809ccaa..1db632c144 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -25,6 +25,7 @@ class TestMedNISTDataset(unittest.TestCase): + @skip_if_quick def test_values(self): testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") @@ -64,11 +65,8 @@ def _test_dataset(dataset): self.assertEqual(data[0]["class_name"], "AbdomenCT") self.assertEqual(data[0]["label"], 0) shutil.rmtree(os.path.join(testing_dir, "MedNIST")) - try: + with self.assertRaisesRegex(RuntimeError, "^Cannot find dataset directory"): MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=False) - except RuntimeError as e: - print(str(e)) - self.assertTrue(str(e).startswith("Cannot find dataset directory")) if __name__ == "__main__": diff --git a/tests/test_meta_affine.py b/tests/test_meta_affine.py index b95ea3f1ac..95764a0c89 100644 --- a/tests/test_meta_affine.py +++ b/tests/test_meta_affine.py @@ -123,6 +123,7 @@ def _resample_to_affine(itk_obj, ref_obj): @unittest.skipUnless(has_itk, "Requires itk package.") class TestAffineConsistencyITK(unittest.TestCase): + @classmethod def setUpClass(cls): super().setUpClass() diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 0cd0522036..1e0f188b63 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -50,6 +50,7 @@ def rand_string(min_len=5, max_len=10): class TestMetaTensor(unittest.TestCase): + @staticmethod def get_im(shape=None, dtype=None, device=None): if shape is None: diff --git a/tests/test_metatensor_integration.py b/tests/test_metatensor_integration.py index 6a4c67d160..d647e47e74 100644 --- a/tests/test_metatensor_integration.py +++ b/tests/test_metatensor_integration.py @@ -39,6 +39,7 @@ @unittest.skipUnless(has_nib, "Requires nibabel package.") class TestMetaTensorIntegration(unittest.TestCase): + @classmethod def setUpClass(cls): super().setUpClass() diff --git a/tests/test_metrics_reloaded.py b/tests/test_metrics_reloaded.py index 010326b87d..562693c07c 100644 --- a/tests/test_metrics_reloaded.py +++ b/tests/test_metrics_reloaded.py @@ -76,6 +76,7 @@ @unittest.skipIf(not has_metrics, "MetricsReloaded not available.") class TestMetricsReloaded(unittest.TestCase): + @parameterized.expand(TEST_CASES_BINARY) def test_binary(self, input_param, input_data, expected_val): metric = MetricsReloadedBinary(**input_param) diff --git a/tests/test_milmodel.py b/tests/test_milmodel.py index 9178e0bccb..42116e8220 100644 --- a/tests/test_milmodel.py +++ b/tests/test_milmodel.py @@ -63,6 +63,7 @@ class TestMilModel(unittest.TestCase): + @parameterized.expand(TEST_CASE_MILMODEL) def test_shape(self, input_param, input_shape, expected_shape): with skip_if_downloading_fails(): diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 8ad66ebc6e..54f70d3318 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -33,6 +33,7 @@ class TestMLPBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_MLP) def test_shape(self, input_param, input_shape, expected_shape): net = MLPBlock(**input_param) diff --git a/tests/test_mmar_download.py b/tests/test_mmar_download.py index 66fca6bb7f..6af3d09fb2 100644 --- a/tests/test_mmar_download.py +++ b/tests/test_mmar_download.py @@ -116,6 +116,7 @@ @unittest.skip("deprecating mmar tests") class TestMMMARDownload(unittest.TestCase): + @parameterized.expand(TEST_CASES) @skip_if_quick def test_download(self, idx): diff --git a/tests/test_module_list.py b/tests/test_module_list.py index 293da95d5a..d21ba53b7c 100644 --- a/tests/test_module_list.py +++ b/tests/test_module_list.py @@ -21,6 +21,7 @@ class TestAllImport(unittest.TestCase): + def test_public_api(self): """ This is to check "monai.__all__" should be consistent with diff --git a/tests/test_monai_env_vars.py b/tests/test_monai_env_vars.py index 6e0d6f0ddf..f5ef28a0ac 100644 --- a/tests/test_monai_env_vars.py +++ b/tests/test_monai_env_vars.py @@ -18,6 +18,7 @@ class TestMONAIEnvVars(unittest.TestCase): + @classmethod def setUpClass(cls): super(__class__, cls).setUpClass() diff --git a/tests/test_monai_utils_misc.py b/tests/test_monai_utils_misc.py index 742c9e4047..f4eb5d3956 100644 --- a/tests/test_monai_utils_misc.py +++ b/tests/test_monai_utils_misc.py @@ -40,11 +40,13 @@ class MiscClass: + def __init__(self, arg1, arg2, kwargs1=None, kwargs2=None): pass class TestToTupleOfDictionaries(unittest.TestCase): + @parameterized.expand(TO_TUPLE_OF_DICTIONARIES_TEST_CASES) def test_to_tuple_of_dictionaries(self, dictionary, keys, expected): self._test_to_tuple_of_dictionaries(dictionary, keys, expected) @@ -61,6 +63,7 @@ def _test_to_tuple_of_dictionaries(self, dictionary, keys, expected): class TestMiscKwargs(unittest.TestCase): + def test_kwargs(self): present, extra_args = self._custom_user_function(MiscClass, 1, kwargs1="value1", kwargs2="value2") self.assertEqual(present, True) @@ -74,6 +77,7 @@ def _custom_user_function(self, cls, *args, **kwargs): class TestCommandRunner(unittest.TestCase): + def setUp(self): self.orig_flag = str(MONAIEnvVars.debug()) @@ -88,12 +92,11 @@ def test_run_cmd(self): cmd2 = "-c" cmd3 = 'import sys; print("\\tThis is on stderr\\n", file=sys.stderr); sys.exit(1)' os.environ["MONAI_DEBUG"] = str(True) - try: + with self.assertRaises(RuntimeError) as cm: run_cmd([cmd1, cmd2, cmd3], check=True) - except RuntimeError as err: - self.assertIn("This is on stderr", str(err)) - self.assertNotIn("\\n", str(err)) - self.assertNotIn("\\t", str(err)) + self.assertIn("This is on stderr", str(cm.exception)) + self.assertNotIn("\\n", str(cm.exception)) + self.assertNotIn("\\t", str(cm.exception)) if __name__ == "__main__": diff --git a/tests/test_mri_utils.py b/tests/test_mri_utils.py index 2f67816e2e..aabf06d02e 100644 --- a/tests/test_mri_utils.py +++ b/tests/test_mri_utils.py @@ -27,6 +27,7 @@ class TestMRIUtils(unittest.TestCase): + @parameterized.expand(TESTS) def test_rss(self, test_data, res_data): result = root_sum_of_squares(test_data, spatial_dim=1) diff --git a/tests/test_multi_scale.py b/tests/test_multi_scale.py index 8b8acb2503..0b49087216 100644 --- a/tests/test_multi_scale.py +++ b/tests/test_multi_scale.py @@ -52,22 +52,30 @@ class TestMultiScale(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): result = MultiScaleLoss(**input_param).forward(**input_data) np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) - def test_ill_opts(self): - with self.assertRaisesRegex(ValueError, ""): - MultiScaleLoss(loss=dice_loss, kernel="none") - with self.assertRaisesRegex(ValueError, ""): - MultiScaleLoss(loss=dice_loss, scales=[-1])( - torch.ones((1, 1, 3), device=device), torch.ones((1, 1, 3), device=device) - ) - with self.assertRaisesRegex(ValueError, ""): - MultiScaleLoss(loss=dice_loss, scales=[-1], reduction="none")( - torch.ones((1, 1, 3), device=device), torch.ones((1, 1, 3), device=device) - ) + @parameterized.expand( + [ + ({"loss": dice_loss, "kernel": "none"}, None, None), # kernel_none + ({"loss": dice_loss, "scales": [-1]}, torch.ones((1, 1, 3)), torch.ones((1, 1, 3))), # scales_negative + ( + {"loss": dice_loss, "scales": [-1], "reduction": "none"}, + torch.ones((1, 1, 3)), + torch.ones((1, 1, 3)), + ), # scales_negative_reduction_none + ] + ) + def test_ill_opts(self, kwargs, input, target): + if input is None and target is None: + with self.assertRaisesRegex(ValueError, ""): + MultiScaleLoss(**kwargs) + else: + with self.assertRaisesRegex(ValueError, ""): + MultiScaleLoss(**kwargs)(input, target) def test_script(self): input_param, input_data, expected_val = TEST_CASES[0] diff --git a/tests/test_net_adapter.py b/tests/test_net_adapter.py index 74a2daab9d..242326e242 100644 --- a/tests/test_net_adapter.py +++ b/tests/test_net_adapter.py @@ -42,6 +42,7 @@ class TestNetAdapter(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_shape(self, input_param, input_shape, expected_shape): spatial_dims = input_param["dim"] diff --git a/tests/test_network_consistency.py b/tests/test_network_consistency.py index aca145a03d..4182501808 100644 --- a/tests/test_network_consistency.py +++ b/tests/test_network_consistency.py @@ -38,6 +38,7 @@ class TestNetworkConsistency(unittest.TestCase): + def setUp(self): set_determinism(0) diff --git a/tests/test_nifti_endianness.py b/tests/test_nifti_endianness.py index 2539d95fd5..4475d8aaab 100644 --- a/tests/test_nifti_endianness.py +++ b/tests/test_nifti_endianness.py @@ -46,6 +46,7 @@ class TestNiftiEndianness(unittest.TestCase): + def setUp(self): self.im, _ = create_test_image_2d(100, 100) self.fname = tempfile.NamedTemporaryFile(suffix=".nii.gz").name diff --git a/tests/test_nifti_header_revise.py b/tests/test_nifti_header_revise.py index 3d000160e1..411c783fb5 100644 --- a/tests/test_nifti_header_revise.py +++ b/tests/test_nifti_header_revise.py @@ -20,6 +20,7 @@ class TestRectifyHeaderSformQform(unittest.TestCase): + def test_revise_q(self): img = nib.Nifti1Image(np.zeros((10, 10, 10)), np.eye(4)) img.header.set_zooms((0.1, 0.2, 0.3)) diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py index f45c2ac5a7..8543fcea30 100644 --- a/tests/test_nifti_rw.py +++ b/tests/test_nifti_rw.py @@ -72,6 +72,7 @@ class TestNiftiLoadRead(unittest.TestCase): + @parameterized.expand(TESTS) def test_orientation(self, array, affine, reader_param, expected): test_image = make_nifti_image(array, affine) diff --git a/tests/test_normalize_intensity.py b/tests/test_normalize_intensity.py index 193b5cc4b2..72ebf579e1 100644 --- a/tests/test_normalize_intensity.py +++ b/tests/test_normalize_intensity.py @@ -83,6 +83,7 @@ class TestNormalizeIntensity(NumpyImageTestCase2D): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_default(self, im_type): im = im_type(self.imt.copy()) diff --git a/tests/test_normalize_intensityd.py b/tests/test_normalize_intensityd.py index 451269b1c4..229dcd00ff 100644 --- a/tests/test_normalize_intensityd.py +++ b/tests/test_normalize_intensityd.py @@ -51,6 +51,7 @@ class TestNormalizeIntensityd(NumpyImageTestCase2D): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_image_normalize_intensityd(self, im_type): key = "img" diff --git a/tests/test_npzdictitemdataset.py b/tests/test_npzdictitemdataset.py index 4ff4577b72..e2196f1907 100644 --- a/tests/test_npzdictitemdataset.py +++ b/tests/test_npzdictitemdataset.py @@ -21,6 +21,7 @@ class TestNPZDictItemDataset(unittest.TestCase): + def test_load_stream(self): dat0 = np.random.rand(10, 1, 4, 4) dat1 = np.random.rand(10, 1, 4, 4) diff --git a/tests/test_nrrd_reader.py b/tests/test_nrrd_reader.py index 01fabe65a8..649b9fa94d 100644 --- a/tests/test_nrrd_reader.py +++ b/tests/test_nrrd_reader.py @@ -48,6 +48,7 @@ @skipUnless(has_nrrd, "nrrd required") class TestNrrdReader(unittest.TestCase): + def test_verify_suffix(self): reader = NrrdReader() self.assertFalse(reader.verify_suffix("test_image.nrd")) diff --git a/tests/test_nuclick_transforms.py b/tests/test_nuclick_transforms.py index fcdd362b01..a6e66c3658 100644 --- a/tests/test_nuclick_transforms.py +++ b/tests/test_nuclick_transforms.py @@ -179,6 +179,7 @@ class TestFilterImaged(unittest.TestCase): + @parameterized.expand([FILTER_IMAGE_TEST_CASE_1]) def test_correct_shape(self, arguments, input_data, expected_shape): result = FilterImaged(**arguments)(input_data) @@ -186,6 +187,7 @@ def test_correct_shape(self, arguments, input_data, expected_shape): class TestFlattenLabeld(unittest.TestCase): + @parameterized.expand([FLATTEN_LABEL_TEST_CASE_1, FLATTEN_LABEL_TEST_CASE_2, FLATTEN_LABEL_TEST_CASE_3]) def test_correct_num_labels(self, arguments, input_data, expected_result): result = FlattenLabeld(**arguments)(input_data) @@ -193,6 +195,7 @@ def test_correct_num_labels(self, arguments, input_data, expected_result): class TestExtractPatchd(unittest.TestCase): + @parameterized.expand([EXTRACT_TEST_CASE_1, EXTRACT_TEST_CASE_2, EXTRACT_TEST_CASE_3]) def test_correct_patch_size(self, arguments, input_data, expected_shape): result = ExtractPatchd(**arguments)(input_data) @@ -205,6 +208,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestSplitLabelsd(unittest.TestCase): + @parameterized.expand([SPLIT_TEST_CASE_1, SPLIT_TEST_CASE_2]) def test_correct_results(self, arguments, input_data, expected_result): result = SplitLabeld(**arguments)(input_data) @@ -212,6 +216,7 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestGuidanceSignal(unittest.TestCase): + @parameterized.expand([GUIDANCE_TEST_CASE_1, GUIDANCE_TEST_CASE_2]) def test_correct_shape(self, arguments, input_data, expected_shape): result = AddPointGuidanceSignald(**arguments)(input_data) @@ -219,6 +224,7 @@ def test_correct_shape(self, arguments, input_data, expected_shape): class TestClickSignal(unittest.TestCase): + @parameterized.expand([CLICK_TEST_CASE_1, CLICK_TEST_CASE_2]) def test_correct_shape(self, arguments, input_data, expected_shape): result = AddClickSignalsd(**arguments)(input_data) @@ -226,6 +232,7 @@ def test_correct_shape(self, arguments, input_data, expected_shape): class TestPostFilterLabel(unittest.TestCase): + @parameterized.expand([LABEL_FILTER_TEST_CASE_1]) def test_correct_shape(self, arguments, input_data, expected_shape): result = PostFilterLabeld(**arguments)(input_data) @@ -233,6 +240,7 @@ def test_correct_shape(self, arguments, input_data, expected_shape): class TestAddLabelAsGuidance(unittest.TestCase): + @parameterized.expand([LABEL_GUIDANCE_TEST_CASE_1]) def test_correct_shape(self, arguments, input_data, expected_shape): result = AddLabelAsGuidanced(**arguments)(input_data) @@ -240,6 +248,7 @@ def test_correct_shape(self, arguments, input_data, expected_shape): class TestSetLabelClass(unittest.TestCase): + @parameterized.expand([LABEL_CLASS_TEST_CASE_1]) def test_correct_results(self, arguments, input_data, expected_result): result = SetLabelClassd(**arguments)(input_data) diff --git a/tests/test_numpy_reader.py b/tests/test_numpy_reader.py index eeff2922ad..6303598bb7 100644 --- a/tests/test_numpy_reader.py +++ b/tests/test_numpy_reader.py @@ -24,6 +24,7 @@ class TestNumpyReader(unittest.TestCase): + def test_npy(self): test_data = np.random.randint(0, 256, size=[3, 4, 4]) with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_nvtx_decorator.py b/tests/test_nvtx_decorator.py index 574fd49592..efd2906972 100644 --- a/tests/test_nvtx_decorator.py +++ b/tests/test_nvtx_decorator.py @@ -72,6 +72,7 @@ @unittest.skipUnless(has_nvtx, "Required torch._C._nvtx for NVTX Range!") class TestNVTXRangeDecorator(unittest.TestCase): + @parameterized.expand([TEST_CASE_ARRAY_0, TEST_CASE_ARRAY_1]) def test_tranform_array(self, input): transforms = Compose([Range("random flip")(Flip()), Range()(ToTensor())]) diff --git a/tests/test_nvtx_transform.py b/tests/test_nvtx_transform.py index 3a5314c35f..af15c53d1b 100644 --- a/tests/test_nvtx_transform.py +++ b/tests/test_nvtx_transform.py @@ -43,6 +43,7 @@ class TestNVTXTransforms(unittest.TestCase): + @parameterized.expand([TEST_CASE_ARRAY_0, TEST_CASE_ARRAY_1, TEST_CASE_DICT_0, TEST_CASE_DICT_1]) @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX!") def test_nvtx_transfroms_alone(self, input): diff --git a/tests/test_occlusion_sensitivity.py b/tests/test_occlusion_sensitivity.py index c7ac5ef533..d821c9bcd9 100644 --- a/tests/test_occlusion_sensitivity.py +++ b/tests/test_occlusion_sensitivity.py @@ -22,6 +22,7 @@ class DenseNetAdjoint(DenseNet121): + def __call__(self, x, adjoint_info): if adjoint_info != 42: raise ValueError @@ -104,6 +105,7 @@ def __call__(self, x, adjoint_info): class TestComputeOcclusionSensitivity(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, init_data, call_data, map_expected_shape, most_prob_expected_shape): occ_sens = OcclusionSensitivity(**init_data) diff --git a/tests/test_one_of.py b/tests/test_one_of.py index 2909597507..ecf1cb3319 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -39,31 +39,37 @@ class X(Transform): + def __call__(self, x): return x class Y(Transform): + def __call__(self, x): return x class A(Transform): + def __call__(self, x): return x + 1 class B(Transform): + def __call__(self, x): return x + 2 class C(Transform): + def __call__(self, x): return x + 3 class MapBase(MapTransform): + def __init__(self, keys): super().__init__(keys) self.fwd_fn, self.inv_fn = None, None @@ -76,12 +82,14 @@ def __call__(self, data): class NonInv(MapBase): + def __init__(self, keys): super().__init__(keys) self.fwd_fn = lambda x: x * 2 class Inv(MapBase, InvertibleTransform): + def __call__(self, data): d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -98,6 +106,7 @@ def inverse(self, data): class InvA(Inv): + def __init__(self, keys): super().__init__(keys) self.fwd_fn = lambda x: x + 1 @@ -105,6 +114,7 @@ def __init__(self, keys): class InvB(Inv): + def __init__(self, keys): super().__init__(keys) self.fwd_fn = lambda x: x + 100 @@ -123,6 +133,7 @@ def __init__(self, keys): class TestOneOf(unittest.TestCase): + @parameterized.expand(TESTS) def test_normalize_weights(self, transforms, input_weights, expected_weights): tr = OneOf(transforms, input_weights) @@ -240,6 +251,7 @@ def test_one_of(self): class TestOneOfAPITests(unittest.TestCase): + @staticmethod def data_from_keys(keys): if keys is None: diff --git a/tests/test_optional_import.py b/tests/test_optional_import.py index 03db7b3fc6..2f640f88d0 100644 --- a/tests/test_optional_import.py +++ b/tests/test_optional_import.py @@ -13,21 +13,20 @@ import unittest +from parameterized import parameterized + from monai.utils import OptionalImportError, exact_version, optional_import class TestOptionalImport(unittest.TestCase): - def test_default(self): - my_module, flag = optional_import("not_a_module") + + @parameterized.expand(["not_a_module", "torch.randint"]) + def test_default(self, import_module): + my_module, flag = optional_import(import_module) self.assertFalse(flag) with self.assertRaises(OptionalImportError): my_module.test - my_module, flag = optional_import("torch.randint") - with self.assertRaises(OptionalImportError): - self.assertFalse(flag) - print(my_module.test) - def test_import_valid(self): my_module, flag = optional_import("torch") self.assertTrue(flag) @@ -46,18 +45,9 @@ def test_import_wrong_number(self): self.assertTrue(flag) print(my_module.randint(1, 2, (1, 2))) - def test_import_good_number(self): - my_module, flag = optional_import("torch", "0") - my_module.nn - self.assertTrue(flag) - print(my_module.randint(1, 2, (1, 2))) - - my_module, flag = optional_import("torch", "0.0.0.1") - my_module.nn - self.assertTrue(flag) - print(my_module.randint(1, 2, (1, 2))) - - my_module, flag = optional_import("torch", "1.1.0") + @parameterized.expand(["0", "0.0.0.1", "1.1.0"]) + def test_import_good_number(self, version_number): + my_module, flag = optional_import("torch", version_number) my_module.nn self.assertTrue(flag) print(my_module.randint(1, 2, (1, 2))) diff --git a/tests/test_ori_ras_lps.py b/tests/test_ori_ras_lps.py index 824793f927..39c0a57877 100644 --- a/tests/test_ori_ras_lps.py +++ b/tests/test_ori_ras_lps.py @@ -38,6 +38,7 @@ class TestITKWriter(unittest.TestCase): + @parameterized.expand(TEST_CASES_AFFINE) def test_ras_to_lps(self, param, expected): assert_allclose(orientation_ras_lps(param), expected) diff --git a/tests/test_orientation.py b/tests/test_orientation.py index aa1c326bdf..2f3334e622 100644 --- a/tests/test_orientation.py +++ b/tests/test_orientation.py @@ -177,6 +177,7 @@ class TestOrientationCase(unittest.TestCase): + @parameterized.expand(TESTS) def test_ornt_meta( self, diff --git a/tests/test_orientationd.py b/tests/test_orientationd.py index cf4eb23d42..b885266c69 100644 --- a/tests/test_orientationd.py +++ b/tests/test_orientationd.py @@ -65,6 +65,7 @@ class TestOrientationdCase(unittest.TestCase): + @parameterized.expand(TESTS) def test_orntd( self, init_param, img: torch.Tensor, affine: torch.Tensor | None, expected_shape, expected_code, device diff --git a/tests/test_p3d_block.py b/tests/test_p3d_block.py index db9e9c284d..1a4ea6c884 100644 --- a/tests/test_p3d_block.py +++ b/tests/test_p3d_block.py @@ -62,6 +62,7 @@ class TestP3D(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) def test_3d(self, input_param, input_shape, expected_shape): net = P3DActiConvNormBlock(**input_param) diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index cd98f29abf..17f49611df 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -60,6 +60,7 @@ def _testing_collate(x): class _Dataset(torch.utils.data.Dataset): + def __init__(self, images, labels, transforms): self.images = images self.labels = labels @@ -73,6 +74,7 @@ def __getitem__(self, index): class TestPadCollation(unittest.TestCase): + def setUp(self) -> None: set_determinism(seed=0) # image is non square to throw rotation errors @@ -115,7 +117,7 @@ def test_pad_collation(self, t_type, collate_method, transform): batch_inverse = BatchInverseTransform(dataset.transform, loader) for data in loader: output = batch_inverse(data) - self.assertTrue(output[0]["image"].shape, (1, 10, 9)) + self.assertEqual(output[0]["image"].shape, (1, 10, 9)) if __name__ == "__main__": diff --git a/tests/test_pad_mode.py b/tests/test_pad_mode.py index 722d5b573f..54ee2c6d75 100644 --- a/tests/test_pad_mode.py +++ b/tests/test_pad_mode.py @@ -23,6 +23,7 @@ @SkipIfBeforePyTorchVersion((1, 10, 1)) class TestPadMode(unittest.TestCase): + def test_pad(self): expected_shapes = {3: (1, 15, 10), 4: (1, 10, 6, 7)} for t in (float, int, np.uint8, np.int16, np.float32, bool): diff --git a/tests/test_partition_dataset.py b/tests/test_partition_dataset.py index 8640d8cc73..c93a6c7682 100644 --- a/tests/test_partition_dataset.py +++ b/tests/test_partition_dataset.py @@ -118,6 +118,7 @@ class TestPartitionDataset(unittest.TestCase): + @parameterized.expand( [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] ) diff --git a/tests/test_partition_dataset_classes.py b/tests/test_partition_dataset_classes.py index c4fa5ed199..4c13b2f463 100644 --- a/tests/test_partition_dataset_classes.py +++ b/tests/test_partition_dataset_classes.py @@ -76,6 +76,7 @@ class TestPartitionDatasetClasses(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_value(self, input_param, result): self.assertListEqual(partition_dataset_classes(**input_param), result) diff --git a/tests/test_patch_dataset.py b/tests/test_patch_dataset.py index eb705f0c61..9a81d84363 100644 --- a/tests/test_patch_dataset.py +++ b/tests/test_patch_dataset.py @@ -27,6 +27,7 @@ def identity(x): class TestPatchDataset(unittest.TestCase): + def test_shape(self): test_dataset = ["vwxyz", "hello", "world"] n_per_image = len(test_dataset[0]) diff --git a/tests/test_patch_inferer.py b/tests/test_patch_inferer.py index 032d22bb98..c6308224b0 100644 --- a/tests/test_patch_inferer.py +++ b/tests/test_patch_inferer.py @@ -245,6 +245,7 @@ class PatchInfererTests(unittest.TestCase): + @parameterized.expand( [ TEST_CASE_0_TENSOR, diff --git a/tests/test_patch_wsi_dataset.py b/tests/test_patch_wsi_dataset.py index cb9ebcf7e3..70e01eaaf4 100644 --- a/tests/test_patch_wsi_dataset.py +++ b/tests/test_patch_wsi_dataset.py @@ -128,6 +128,7 @@ def setUpModule(): class PatchWSIDatasetTests: + class Tests(unittest.TestCase): backend = None @@ -182,6 +183,7 @@ def test_read_patches_str_multi(self, input_parameters, expected): @skipUnless(has_cim, "Requires cucim") class TestPatchWSIDatasetCuCIM(PatchWSIDatasetTests.Tests): + @classmethod def setUpClass(cls): cls.backend = "cucim" @@ -189,6 +191,7 @@ def setUpClass(cls): @skipUnless(has_osl, "Requires openslide") class TestPatchWSIDatasetOpenSlide(PatchWSIDatasetTests.Tests): + @classmethod def setUpClass(cls): cls.backend = "openslide" diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py index 77ade984eb..d059145033 100644 --- a/tests/test_patchembedding.py +++ b/tests/test_patchembedding.py @@ -77,6 +77,7 @@ @SkipIfBeforePyTorchVersion((1, 11, 1)) class TestPatchEmbeddingBlock(unittest.TestCase): + def setUp(self): self.threads = torch.get_num_threads() torch.set_num_threads(4) @@ -92,6 +93,32 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) + def test_sincos_pos_embed(self): + net = PatchEmbeddingBlock( + in_channels=1, + img_size=(32, 32, 32), + patch_size=(8, 8, 8), + hidden_size=96, + num_heads=8, + pos_embed_type="sincos", + dropout_rate=0.5, + ) + + self.assertEqual(net.position_embeddings.requires_grad, False) + + def test_learnable_pos_embed(self): + net = PatchEmbeddingBlock( + in_channels=1, + img_size=(32, 32, 32), + patch_size=(8, 8, 8), + hidden_size=96, + num_heads=8, + pos_embed_type="learnable", + dropout_rate=0.5, + ) + + self.assertEqual(net.position_embeddings.requires_grad, True) + def test_ill_arg(self): with self.assertRaises(ValueError): PatchEmbeddingBlock( @@ -162,6 +189,7 @@ def test_ill_arg(self): class TestPatchEmbed(unittest.TestCase): + def setUp(self): self.threads = torch.get_num_threads() torch.set_num_threads(4) diff --git a/tests/test_pathology_he_stain.py b/tests/test_pathology_he_stain.py index 7ddad4ad6f..26941c6abb 100644 --- a/tests/test_pathology_he_stain.py +++ b/tests/test_pathology_he_stain.py @@ -73,6 +73,7 @@ class TestExtractHEStains(unittest.TestCase): + @parameterized.expand( [NEGATIVE_VALUE_TEST_CASE, INVALID_VALUE_TEST_CASE, EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_1] ) @@ -145,6 +146,7 @@ def test_result_value(self, image, expected_data): class TestNormalizeHEStains(unittest.TestCase): + @parameterized.expand( [NEGATIVE_VALUE_TEST_CASE, INVALID_VALUE_TEST_CASE, NORMALIZE_STAINS_TEST_CASE_0, NORMALIZE_STAINS_TEST_CASE_1] ) diff --git a/tests/test_pathology_he_stain_dict.py b/tests/test_pathology_he_stain_dict.py index 07db1c3e48..975dc4ffb8 100644 --- a/tests/test_pathology_he_stain_dict.py +++ b/tests/test_pathology_he_stain_dict.py @@ -67,6 +67,7 @@ class TestExtractHEStainsD(unittest.TestCase): + @parameterized.expand([EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_1]) def test_transparent_image(self, image): """ @@ -140,6 +141,7 @@ def test_result_value(self, image, expected_data): class TestNormalizeHEStainsD(unittest.TestCase): + @parameterized.expand([NORMALIZE_STAINS_TEST_CASE_0, NORMALIZE_STAINS_TEST_CASE_1]) def test_transparent_image(self, image): """ diff --git a/tests/test_pathology_prob_nms.py b/tests/test_pathology_prob_nms.py index 0053500437..b3d7da2c1d 100644 --- a/tests/test_pathology_prob_nms.py +++ b/tests/test_pathology_prob_nms.py @@ -43,6 +43,7 @@ class TestPathologyProbNMS(unittest.TestCase): + @parameterized.expand([TEST_CASES_2D, TEST_CASES_3D]) def test_output(self, class_args, call_args, probs_map, expected): nms = PathologyProbNMS(**class_args) diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py index 7e4860e7f9..b8aa2e5982 100644 --- a/tests/test_perceptual_loss.py +++ b/tests/test_perceptual_loss.py @@ -18,7 +18,7 @@ from monai.losses import PerceptualLoss from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_quick +from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose, skip_if_downloading_fails, skip_if_quick _, has_torchvision = optional_import("torchvision") TEST_CASES = [ @@ -40,6 +40,31 @@ (2, 1, 64, 64, 64), (2, 1, 64, 64, 64), ], + [ + {"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False}, + (2, 6, 64, 64, 64), + (2, 6, 64, 64, 64), + ], + [ + { + "spatial_dims": 3, + "network_type": "medicalnet_resnet10_23datasets", + "is_fake_3d": False, + "channel_wise": True, + }, + (2, 6, 64, 64, 64), + (2, 6, 64, 64, 64), + ], + [ + {"spatial_dims": 3, "network_type": "medicalnet_resnet50_23datasets", "is_fake_3d": False}, + (2, 1, 64, 64, 64), + (2, 1, 64, 64, 64), + ], + [ + {"spatial_dims": 3, "network_type": "medicalnet_resnet50_23datasets", "is_fake_3d": False}, + (2, 6, 64, 64, 64), + (2, 6, 64, 64, 64), + ], [ {"spatial_dims": 3, "network_type": "resnet50", "is_fake_3d": True, "pretrained": True, "fake_3d_ratio": 0.2}, (2, 1, 64, 64, 64), @@ -52,12 +77,17 @@ @unittest.skipUnless(has_torchvision, "Requires torchvision") @skip_if_quick class TestPerceptualLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_shape, target_shape): with skip_if_downloading_fails(): loss = PerceptualLoss(**input_param) result = loss(torch.randn(input_shape), torch.randn(target_shape)) - self.assertEqual(result.shape, torch.Size([])) + + if "channel_wise" in input_param.keys() and input_param["channel_wise"]: + self.assertEqual(result.shape, torch.Size([input_shape[1]])) + else: + self.assertEqual(result.shape, torch.Size([])) @parameterized.expand(TEST_CASES) def test_identical_input(self, input_param, input_shape, target_shape): @@ -65,7 +95,11 @@ def test_identical_input(self, input_param, input_shape, target_shape): loss = PerceptualLoss(**input_param) tensor = torch.randn(input_shape) result = loss(tensor, tensor) - self.assertEqual(result, torch.Tensor([0.0])) + + if "channel_wise" in input_param.keys() and input_param["channel_wise"]: + assert_allclose(result, torch.Tensor([0.0] * input_shape[1])) + else: + self.assertEqual(result, torch.Tensor([0.0])) def test_different_shape(self): with skip_if_downloading_fails(): @@ -79,12 +113,10 @@ def test_1d(self): with self.assertRaises(NotImplementedError): PerceptualLoss(spatial_dims=1) - def test_medicalnet_on_2d_data(self): - with self.assertRaises(ValueError): - PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet10_23datasets") - + @parameterized.expand(["medicalnet_resnet10_23datasets", "medicalnet_resnet50_23datasets"]) + def test_medicalnet_on_2d_data(self, network_type): with self.assertRaises(ValueError): - PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet50_23datasets") + PerceptualLoss(spatial_dims=2, network_type=network_type) if __name__ == "__main__": diff --git a/tests/test_persistentdataset.py b/tests/test_persistentdataset.py index 1b8245e318..b7bf2fbb11 100644 --- a/tests/test_persistentdataset.py +++ b/tests/test_persistentdataset.py @@ -45,6 +45,7 @@ class _InplaceXform(Transform): + def __call__(self, data): if data: data[0] = data[0] + np.pi @@ -54,6 +55,7 @@ def __call__(self, data): class TestDataset(unittest.TestCase): + def test_cache(self): """testing no inplace change to the hashed item""" items = [[list(range(i))] for i in range(5)] diff --git a/tests/test_persistentdataset_dist.py b/tests/test_persistentdataset_dist.py index e69c32b1eb..c369af9e92 100644 --- a/tests/test_persistentdataset_dist.py +++ b/tests/test_persistentdataset_dist.py @@ -25,6 +25,7 @@ class _InplaceXform(Transform): + def __call__(self, data): if data: data[0] = data[0] + np.pi @@ -34,6 +35,7 @@ def __call__(self, data): class TestDistDataset(DistTestCase): + def setUp(self): self.tempdir = tempfile.mkdtemp() @@ -58,6 +60,7 @@ def test_mp_dataset(self): class TestDistCreateDataset(DistTestCase): + def setUp(self): self.tempdir = tempfile.mkdtemp() diff --git a/tests/test_phl_cpu.py b/tests/test_phl_cpu.py index 98a5018d8e..6f872a4776 100644 --- a/tests/test_phl_cpu.py +++ b/tests/test_phl_cpu.py @@ -242,6 +242,7 @@ @skip_if_no_cpp_extension class PHLFilterTestCaseCpu(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_cpu(self, test_case_description, sigmas, input, features, expected): # Create input tensors diff --git a/tests/test_phl_cuda.py b/tests/test_phl_cuda.py index 0ddfd5eaae..b410ea8722 100644 --- a/tests/test_phl_cuda.py +++ b/tests/test_phl_cuda.py @@ -150,6 +150,7 @@ @skip_if_no_cuda @skip_if_no_cpp_extension class PHLFilterTestCaseCuda(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_cuda(self, test_case_description, sigmas, input, features, expected): # Create input tensors diff --git a/tests/test_pil_reader.py b/tests/test_pil_reader.py index dfa5eb725d..078812513d 100644 --- a/tests/test_pil_reader.py +++ b/tests/test_pil_reader.py @@ -37,6 +37,7 @@ class TestPNGReader(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_shape_value(self, data_shape, filenames, expected_shape, meta_shape, reverse=True): test_image = np.random.randint(0, 256, size=data_shape) diff --git a/tests/test_plot_2d_or_3d_image.py b/tests/test_plot_2d_or_3d_image.py index 180a6c3443..16241853b3 100644 --- a/tests/test_plot_2d_or_3d_image.py +++ b/tests/test_plot_2d_or_3d_image.py @@ -40,6 +40,7 @@ @unittest.skipUnless(has_tb, "Requires SummaryWriter installation") @SkipIfBeforePyTorchVersion((1, 13)) # issue 6683 class TestPlot2dOr3dImage(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_tb_image(self, shape): with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_png_rw.py b/tests/test_png_rw.py index 0b6e8184ea..058cd616cb 100644 --- a/tests/test_png_rw.py +++ b/tests/test_png_rw.py @@ -22,6 +22,7 @@ class TestPngWrite(unittest.TestCase): + def test_write_gray(self): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.png") diff --git a/tests/test_polyval.py b/tests/test_polyval.py index 113c862cb3..f0215678af 100644 --- a/tests/test_polyval.py +++ b/tests/test_polyval.py @@ -31,6 +31,7 @@ class TestPolyval(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_floats(self, coef, x, expected): result = polyval(coef, x) diff --git a/tests/test_prepare_batch_default.py b/tests/test_prepare_batch_default.py index e440f5cfe3..9aa498866f 100644 --- a/tests/test_prepare_batch_default.py +++ b/tests/test_prepare_batch_default.py @@ -14,96 +14,62 @@ import unittest import torch +from parameterized import parameterized from monai.engines import PrepareBatchDefault, SupervisedEvaluator from tests.utils import assert_allclose class TestNet(torch.nn.Module): + def forward(self, x: torch.Tensor): return x class TestPrepareBatchDefault(unittest.TestCase): - def test_dict_content(self): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - dataloader = [ - { - "image": torch.tensor([1, 2]), - "label": torch.tensor([3, 4]), - "extra1": torch.tensor([5, 6]), - "extra2": 16, - "extra3": "test", - } - ] - # set up engine - evaluator = SupervisedEvaluator( - device=device, - val_data_loader=dataloader, - epoch_length=1, - network=TestNet(), - non_blocking=False, - prepare_batch=PrepareBatchDefault(), - decollate=False, - mode="eval", - ) - evaluator.run() - output = evaluator.state.output - assert_allclose(output["image"], torch.tensor([1, 2], device=device)) - assert_allclose(output["label"], torch.tensor([3, 4], device=device)) - - def test_tensor_content(self): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - dataloader = [torch.tensor([1, 2])] - # set up engine - evaluator = SupervisedEvaluator( - device=device, - val_data_loader=dataloader, - epoch_length=1, - network=torch.nn.Identity(), - non_blocking=False, - prepare_batch=PrepareBatchDefault(), - decollate=False, - mode="eval", - ) - evaluator.run() - output = evaluator.state.output - assert_allclose(output["image"], torch.tensor([1, 2], device=device)) - self.assertTrue(output["label"] is None) - - def test_pair_content(self): + @parameterized.expand( + [ + ( + [ + { + "image": torch.tensor([1, 2]), + "label": torch.tensor([3, 4]), + "extra1": torch.tensor([5, 6]), + "extra2": 16, + "extra3": "test", + } + ], + TestNet(), + True, + ), # dict_content + ([torch.tensor([1, 2])], torch.nn.Identity(), True), # tensor_content + ([(torch.tensor([1, 2]), torch.tensor([3, 4]))], torch.nn.Identity(), True), # pair_content + ([], TestNet(), False), # empty_data + ] + ) + def test_prepare_batch(self, dataloader, network, should_run): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - dataloader = [(torch.tensor([1, 2]), torch.tensor([3, 4]))] - - # set up engine evaluator = SupervisedEvaluator( device=device, val_data_loader=dataloader, - epoch_length=1, - network=torch.nn.Identity(), + epoch_length=len(dataloader) if should_run else 0, + network=network, non_blocking=False, prepare_batch=PrepareBatchDefault(), decollate=False, - mode="eval", + mode="eval" if should_run else "train", ) evaluator.run() - output = evaluator.state.output - assert_allclose(output["image"], torch.tensor([1, 2], device=device)) - assert_allclose(output["label"], torch.tensor([3, 4], device=device)) - def test_empty_data(self): - dataloader = [] - evaluator = SupervisedEvaluator( - val_data_loader=dataloader, - device=torch.device("cpu"), - epoch_length=0, - network=TestNet(), - non_blocking=False, - prepare_batch=PrepareBatchDefault(), - decollate=False, - ) - evaluator.run() + if should_run: + output = evaluator.state.output + if isinstance(dataloader[0], dict) or isinstance(dataloader[0], tuple): + assert_allclose(output["image"], torch.tensor([1, 2], device=device)) + assert_allclose(output["label"], torch.tensor([3, 4], device=device)) + else: + assert_allclose(output["image"], torch.tensor([1, 2], device=device)) + self.assertTrue(output["label"] is None) if __name__ == "__main__": diff --git a/tests/test_prepare_batch_default_dist.py b/tests/test_prepare_batch_default_dist.py index d015cf4b2f..0c53a74834 100644 --- a/tests/test_prepare_batch_default_dist.py +++ b/tests/test_prepare_batch_default_dist.py @@ -43,11 +43,13 @@ class TestNet(torch.nn.Module): + def forward(self, x: torch.Tensor): return x class DistributedPrepareBatchDefault(DistTestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) @DistCall(nnodes=1, nproc_per_node=2, node_rank=0) def test_compute(self, dataloaders): diff --git a/tests/test_prepare_batch_extra_input.py b/tests/test_prepare_batch_extra_input.py index 1769a19e4a..f20c6e7352 100644 --- a/tests/test_prepare_batch_extra_input.py +++ b/tests/test_prepare_batch_extra_input.py @@ -36,11 +36,13 @@ class TestNet(torch.nn.Module): + def forward(self, x: torch.Tensor, t1=None, t2=None, t3=None): return {"x": x, "t1": t1, "t2": t2, "t3": t3} class TestPrepareBatchExtraInput(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) def test_content(self, input_args, expected_value): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/tests/test_prepare_batch_hovernet.py b/tests/test_prepare_batch_hovernet.py index 5a7080a225..773fcb53bf 100644 --- a/tests/test_prepare_batch_hovernet.py +++ b/tests/test_prepare_batch_hovernet.py @@ -28,11 +28,13 @@ class TestNet(torch.nn.Module): + def forward(self, x: torch.Tensor): return {HoVerNetBranch.NP: torch.tensor([1, 2]), HoVerNetBranch.NC: torch.tensor([4, 4]), HoVerNetBranch.HV: 16} class TestPrepareBatchHoVerNet(unittest.TestCase): + @parameterized.expand([TEST_CASE_0]) def test_content(self, input_args, expected_value): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/tests/test_preset_filters.py b/tests/test_preset_filters.py index 9bca24cef3..46ed461f7d 100644 --- a/tests/test_preset_filters.py +++ b/tests/test_preset_filters.py @@ -63,6 +63,7 @@ class _TestFilter: + def test_init(self, spatial_dims, size, expected): test_filter = self.filter_class(spatial_dims=spatial_dims, size=size) torch.testing.assert_allclose(expected, test_filter.filter) @@ -75,6 +76,7 @@ def test_forward(self): class TestApplyFilter(unittest.TestCase): + def test_init_and_forward_2d(self): filter_2d = torch.ones(3, 3) image_2d = torch.ones(1, 3, 3) @@ -91,6 +93,7 @@ def test_init_and_forward_3d(self): class MeanFilterTestCase(_TestFilter, unittest.TestCase): + def setUp(self) -> None: self.filter_class = MeanFilter @@ -100,6 +103,7 @@ def test_init(self, spatial_dims, size, expected): class LaplaceFilterTestCase(_TestFilter, unittest.TestCase): + def setUp(self) -> None: self.filter_class = LaplaceFilter @@ -109,6 +113,7 @@ def test_init(self, spatial_dims, size, expected): class EllipticalTestCase(_TestFilter, unittest.TestCase): + def setUp(self) -> None: self.filter_class = EllipticalFilter @@ -118,6 +123,7 @@ def test_init(self, spatial_dims, size, expected): class SharpenTestCase(_TestFilter, unittest.TestCase): + def setUp(self) -> None: self.filter_class = SharpenFilter diff --git a/tests/test_print_info.py b/tests/test_print_info.py index bb748c3f7b..aa152e183c 100644 --- a/tests/test_print_info.py +++ b/tests/test_print_info.py @@ -17,6 +17,7 @@ class TestPrintInfo(unittest.TestCase): + def test_print_info(self): print_debug_info() diff --git a/tests/test_print_transform_backends.py b/tests/test_print_transform_backends.py index 4cd93c3fb2..2072aa4cfa 100644 --- a/tests/test_print_transform_backends.py +++ b/tests/test_print_transform_backends.py @@ -17,6 +17,7 @@ class TestPrintTransformBackends(unittest.TestCase): + def test_get_number_of_conversions(self): tr_t_or_np, *_ = get_transform_backends() self.assertGreater(len(tr_t_or_np), 0) diff --git a/tests/test_probnms.py b/tests/test_probnms.py index 8da5396fac..2b52583ad4 100644 --- a/tests/test_probnms.py +++ b/tests/test_probnms.py @@ -61,6 +61,7 @@ class TestProbNMS(unittest.TestCase): + @parameterized.expand(TESTS) def test_output(self, class_args, probs_map, expected): nms = ProbNMS(**class_args) diff --git a/tests/test_probnmsd.py b/tests/test_probnmsd.py index 1f0288811e..aeb32bdb79 100644 --- a/tests/test_probnmsd.py +++ b/tests/test_probnmsd.py @@ -68,6 +68,7 @@ class TestProbNMS(unittest.TestCase): + @parameterized.expand(TESTS) def test_output(self, class_args, probs_map, expected): nms = ProbNMSD(keys="prob_map", **class_args) diff --git a/tests/test_profiling.py b/tests/test_profiling.py index 2b93fae196..6bee7ba262 100644 --- a/tests/test_profiling.py +++ b/tests/test_profiling.py @@ -29,6 +29,7 @@ class TestWorkflowProfiler(unittest.TestCase): + def setUp(self): super().setUp() diff --git a/tests/test_pytorch_version_after.py b/tests/test_pytorch_version_after.py index 4c8c032c80..147707d2c0 100644 --- a/tests/test_pytorch_version_after.py +++ b/tests/test_pytorch_version_after.py @@ -38,6 +38,7 @@ class TestPytorchVersionCompare(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_compare(self, a, b, p, current, expected=True): """Test pytorch_after with a and b""" diff --git a/tests/test_query_memory.py b/tests/test_query_memory.py index 5e57913acb..77c34ede39 100644 --- a/tests/test_query_memory.py +++ b/tests/test_query_memory.py @@ -17,6 +17,7 @@ class TestQueryMemory(unittest.TestCase): + def test_output_str(self): self.assertTrue(isinstance(query_memory(2), str)) all_device = query_memory(-1) diff --git a/tests/test_quicknat.py b/tests/test_quicknat.py index b4b89b7d62..f6786405d2 100644 --- a/tests/test_quicknat.py +++ b/tests/test_quicknat.py @@ -38,6 +38,7 @@ @unittest.skipUnless(has_se, "squeeze_and_excitation not installed") class TestQuicknat(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/tests/test_rand_adjust_contrast.py b/tests/test_rand_adjust_contrast.py index bfeedc2fcf..72d0df141e 100644 --- a/tests/test_rand_adjust_contrast.py +++ b/tests/test_rand_adjust_contrast.py @@ -25,6 +25,7 @@ class TestRandAdjustContrast(NumpyImageTestCase2D): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_correct_results(self, gamma): adjuster = RandAdjustContrast(prob=1.0, gamma=gamma) diff --git a/tests/test_rand_adjust_contrastd.py b/tests/test_rand_adjust_contrastd.py index 4037266da4..bbd5c22009 100644 --- a/tests/test_rand_adjust_contrastd.py +++ b/tests/test_rand_adjust_contrastd.py @@ -25,6 +25,7 @@ class TestRandAdjustContrastd(NumpyImageTestCase2D): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_correct_results(self, gamma): adjuster = RandAdjustContrastd("img", prob=1.0, gamma=gamma) diff --git a/tests/test_rand_affine.py b/tests/test_rand_affine.py index 915b14bf51..2c827b7426 100644 --- a/tests/test_rand_affine.py +++ b/tests/test_rand_affine.py @@ -140,22 +140,22 @@ class TestRandAffine(unittest.TestCase): + @parameterized.expand(TESTS) def test_rand_affine(self, input_param, input_data, expected_val): g = RandAffine(**input_param) g.set_random_state(123) result = g(**input_data) g.rand_affine_grid.affine = torch.eye(4, dtype=torch.float64) # reset affine - test_resampler_lazy(g, result, input_param, input_data, seed=123) + test_resampler_lazy(g, result, input_param, input_data, seed=123, rtol=_rtol) if input_param.get("cache_grid", False): self.assertTrue(g._cached_grid is not None) assert_allclose(result, expected_val, rtol=_rtol, atol=1e-4, type_test="tensor") - def test_ill_cache(self): - with self.assertWarns(UserWarning): - RandAffine(cache_grid=True) + @parameterized.expand([(None,), ((1, 1, -1),)]) + def test_ill_cache(self, spatial_size): with self.assertWarns(UserWarning): - RandAffine(cache_grid=True, spatial_size=(1, 1, -1)) + RandAffine(cache_grid=True, spatial_size=spatial_size) @parameterized.expand(TEST_CASES_SKIPPED_CONSISTENCY) def test_skipped_transform_consistency(self, im, in_dtype): diff --git a/tests/test_rand_affine_grid.py b/tests/test_rand_affine_grid.py index 113987a85c..91558ebd03 100644 --- a/tests/test_rand_affine_grid.py +++ b/tests/test_rand_affine_grid.py @@ -198,6 +198,7 @@ class TestRandAffineGrid(unittest.TestCase): + @parameterized.expand(TESTS) def test_rand_affine_grid(self, input_param, input_data, expected_val): g = RandAffineGrid(**input_param) diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py index a607029c1a..950058a9e9 100644 --- a/tests/test_rand_affined.py +++ b/tests/test_rand_affined.py @@ -216,6 +216,7 @@ class TestRandAffined(unittest.TestCase): + @parameterized.expand(x + [y] for x, y in itertools.product(TESTS, (False, True))) def test_rand_affined(self, input_param, input_data, expected_val, track_meta): set_track_meta(track_meta) @@ -233,7 +234,9 @@ def test_rand_affined(self, input_param, input_data, expected_val, track_meta): lazy_init_param["keys"], lazy_init_param["mode"] = key, mode resampler = RandAffined(**lazy_init_param).set_random_state(123) expected_output = resampler(**call_param) - test_resampler_lazy(resampler, expected_output, lazy_init_param, call_param, seed=123, output_key=key) + test_resampler_lazy( + resampler, expected_output, lazy_init_param, call_param, seed=123, output_key=key, rtol=_rtol + ) resampler.lazy = False if input_param.get("cache_grid", False): @@ -269,13 +272,10 @@ def test_rand_affined(self, input_param, input_data, expected_val, track_meta): self.assertEqual(len(v.applied_operations), 0) self.assertTupleEqual(v.shape, input_data[k].shape) - def test_ill_cache(self): - with self.assertWarns(UserWarning): - # spatial size is None - RandAffined(device=device, spatial_size=None, prob=1.0, cache_grid=True, keys=("img", "seg")) + @parameterized.expand([(None,), ((2, -1),)]) # spatial size is None # spatial size is dynamic + def test_ill_cache(self, spatial_size): with self.assertWarns(UserWarning): - # spatial size is dynamic - RandAffined(device=device, spatial_size=(2, -1), prob=1.0, cache_grid=True, keys=("img", "seg")) + RandAffined(device=device, spatial_size=spatial_size, prob=1.0, cache_grid=True, keys=("img", "seg")) if __name__ == "__main__": diff --git a/tests/test_rand_axis_flip.py b/tests/test_rand_axis_flip.py index 81e42372db..9c465a0bcb 100644 --- a/tests/test_rand_axis_flip.py +++ b/tests/test_rand_axis_flip.py @@ -23,6 +23,7 @@ class TestRandAxisFlip(NumpyImageTestCase2D): + def test_correct_results(self): for p in TEST_NDARRAYS_ALL: flip = RandAxisFlip(prob=1.0) diff --git a/tests/test_rand_axis_flipd.py b/tests/test_rand_axis_flipd.py index 75357b23e1..d3abef1be4 100644 --- a/tests/test_rand_axis_flipd.py +++ b/tests/test_rand_axis_flipd.py @@ -23,6 +23,7 @@ class TestRandAxisFlip(NumpyImageTestCase3D): + def test_correct_results(self): for p in TEST_NDARRAYS_ALL: flip = RandAxisFlipd(keys="img", prob=1.0) diff --git a/tests/test_rand_bias_field.py b/tests/test_rand_bias_field.py index 16f615146f..333a9ecba5 100644 --- a/tests/test_rand_bias_field.py +++ b/tests/test_rand_bias_field.py @@ -30,6 +30,7 @@ class TestRandBiasField(unittest.TestCase): + @parameterized.expand([TEST_CASES_2D, TEST_CASES_3D]) def test_output_shape(self, class_args, img_shape): for p in TEST_NDARRAYS: diff --git a/tests/test_rand_bias_fieldd.py b/tests/test_rand_bias_fieldd.py index 2b8a60289d..1f174fa397 100644 --- a/tests/test_rand_bias_fieldd.py +++ b/tests/test_rand_bias_fieldd.py @@ -28,6 +28,7 @@ class TestRandBiasFieldd(unittest.TestCase): + @parameterized.expand([TEST_CASES_2D, TEST_CASES_3D]) def test_output_shape(self, class_args, img_shape): key = "img" diff --git a/tests/test_rand_coarse_dropout.py b/tests/test_rand_coarse_dropout.py index 8c3876f10b..ac857f9184 100644 --- a/tests/test_rand_coarse_dropout.py +++ b/tests/test_rand_coarse_dropout.py @@ -63,6 +63,7 @@ class TestRandCoarseDropout(unittest.TestCase): + @parameterized.expand( [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7] ) diff --git a/tests/test_rand_coarse_dropoutd.py b/tests/test_rand_coarse_dropoutd.py index 7b16f992b7..bfc6a2f27f 100644 --- a/tests/test_rand_coarse_dropoutd.py +++ b/tests/test_rand_coarse_dropoutd.py @@ -63,6 +63,7 @@ class TestRandCoarseDropoutd(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_value(self, input_param, input_data): dropout = RandCoarseDropoutd(**input_param) diff --git a/tests/test_rand_coarse_shuffle.py b/tests/test_rand_coarse_shuffle.py index adfb722b42..39e62c22a8 100644 --- a/tests/test_rand_coarse_shuffle.py +++ b/tests/test_rand_coarse_shuffle.py @@ -52,6 +52,7 @@ class TestRandCoarseShuffle(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shuffle(self, input_param, input_data, expected_val): g = RandCoarseShuffle(**input_param) diff --git a/tests/test_rand_coarse_shuffled.py b/tests/test_rand_coarse_shuffled.py index 3b5a1434f4..f49066efd9 100644 --- a/tests/test_rand_coarse_shuffled.py +++ b/tests/test_rand_coarse_shuffled.py @@ -46,6 +46,7 @@ class TestRandCoarseShuffled(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shuffle(self, input_param, input_data, expected_val): g = RandCoarseShuffled(**input_param) diff --git a/tests/test_rand_crop_by_label_classes.py b/tests/test_rand_crop_by_label_classes.py index 88d2631ca5..743b894d75 100644 --- a/tests/test_rand_crop_by_label_classes.py +++ b/tests/test_rand_crop_by_label_classes.py @@ -127,6 +127,7 @@ class TestRandCropByLabelClasses(unittest.TestCase): + @parameterized.expand(TESTS_INDICES + TESTS_SHAPE) def test_type_shape(self, input_param, input_data, expected_type, expected_shape): result = RandCropByLabelClasses(**input_param)(**input_data) diff --git a/tests/test_rand_crop_by_label_classesd.py b/tests/test_rand_crop_by_label_classesd.py index 748f26f1ff..8908c456ee 100644 --- a/tests/test_rand_crop_by_label_classesd.py +++ b/tests/test_rand_crop_by_label_classesd.py @@ -120,6 +120,7 @@ class TestRandCropByLabelClassesd(unittest.TestCase): + @parameterized.expand(TESTS) def test_type_shape(self, input_param, input_data, expected_type, expected_shape): result = RandCropByLabelClassesd(**input_param)(input_data) diff --git a/tests/test_rand_crop_by_pos_neg_label.py b/tests/test_rand_crop_by_pos_neg_label.py index 98af6b0b5e..66e7a5e849 100644 --- a/tests/test_rand_crop_by_pos_neg_label.py +++ b/tests/test_rand_crop_by_pos_neg_label.py @@ -96,6 +96,7 @@ class TestRandCropByPosNegLabel(unittest.TestCase): + @staticmethod def convert_data_type(im_type, d, keys=("img", "image", "label")): out = deepcopy(d) diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index 1b57548d12..11381e226d 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -107,6 +107,7 @@ class TestRandCropByPosNegLabeld(unittest.TestCase): + @staticmethod def convert_data_type(im_type, d, keys=("img", "image", "label")): out = deepcopy(d) diff --git a/tests/test_rand_cucim_dict_transform.py b/tests/test_rand_cucim_dict_transform.py index 33e0667723..3f473897dd 100644 --- a/tests/test_rand_cucim_dict_transform.py +++ b/tests/test_rand_cucim_dict_transform.py @@ -78,6 +78,7 @@ @unittest.skipUnless(HAS_CUPY, "CuPy is required.") @unittest.skipUnless(has_cut, "cuCIM transforms are required.") class TestRandCuCIMDict(unittest.TestCase): + @parameterized.expand( [ TEST_CASE_COLOR_JITTER_1, diff --git a/tests/test_rand_cucim_transform.py b/tests/test_rand_cucim_transform.py index 37d8e29f1d..ce731a05ae 100644 --- a/tests/test_rand_cucim_transform.py +++ b/tests/test_rand_cucim_transform.py @@ -78,6 +78,7 @@ @unittest.skipUnless(HAS_CUPY, "CuPy is required.") @unittest.skipUnless(has_cut, "cuCIM transforms are required.") class TestRandCuCIM(unittest.TestCase): + @parameterized.expand( [ TEST_CASE_COLOR_JITTER_1, diff --git a/tests/test_rand_deform_grid.py b/tests/test_rand_deform_grid.py index 58b64ae596..88fc1333ec 100644 --- a/tests/test_rand_deform_grid.py +++ b/tests/test_rand_deform_grid.py @@ -126,6 +126,7 @@ class TestRandDeformGrid(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_rand_deform_grid(self, input_param, input_data, expected_val): g = RandDeformGrid(**input_param) diff --git a/tests/test_rand_elastic_2d.py b/tests/test_rand_elastic_2d.py index c59052854f..1f3d389a93 100644 --- a/tests/test_rand_elastic_2d.py +++ b/tests/test_rand_elastic_2d.py @@ -110,6 +110,7 @@ class TestRand2DElastic(unittest.TestCase): + @parameterized.expand(TESTS) def test_rand_2d_elastic(self, input_param, input_data, expected_val): g = Rand2DElastic(**input_param) diff --git a/tests/test_rand_elastic_3d.py b/tests/test_rand_elastic_3d.py index 0ff3ef6129..5bfa8a6e83 100644 --- a/tests/test_rand_elastic_3d.py +++ b/tests/test_rand_elastic_3d.py @@ -86,6 +86,7 @@ class TestRand3DElastic(unittest.TestCase): + @parameterized.expand(TESTS) def test_rand_3d_elastic(self, input_param, input_data, expected_val): g = Rand3DElastic(**input_param) diff --git a/tests/test_rand_elasticd_2d.py b/tests/test_rand_elasticd_2d.py index d0fbd5aa88..10aa116192 100644 --- a/tests/test_rand_elasticd_2d.py +++ b/tests/test_rand_elasticd_2d.py @@ -160,6 +160,7 @@ class TestRand2DElasticd(unittest.TestCase): + @parameterized.expand(TESTS) def test_rand_2d_elasticd(self, input_param, input_data, expected_val): g = Rand2DElasticd(**input_param) diff --git a/tests/test_rand_elasticd_3d.py b/tests/test_rand_elasticd_3d.py index e058293584..3838f43f29 100644 --- a/tests/test_rand_elasticd_3d.py +++ b/tests/test_rand_elasticd_3d.py @@ -139,6 +139,7 @@ class TestRand3DElasticd(unittest.TestCase): + @parameterized.expand(TESTS) def test_rand_3d_elasticd(self, input_param, input_data, expected_val): g = Rand3DElasticd(**input_param) diff --git a/tests/test_rand_flip.py b/tests/test_rand_flip.py index c3b0bfdede..faeae94cab 100644 --- a/tests/test_rand_flip.py +++ b/tests/test_rand_flip.py @@ -28,6 +28,7 @@ class TestRandFlip(NumpyImageTestCase2D): + @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, _, spatial_axis, raises): with self.assertRaises(raises): diff --git a/tests/test_rand_flipd.py b/tests/test_rand_flipd.py index be5394c172..a34aa58ed2 100644 --- a/tests/test_rand_flipd.py +++ b/tests/test_rand_flipd.py @@ -26,6 +26,7 @@ class TestRandFlipd(NumpyImageTestCase2D): + @parameterized.expand(VALID_CASES) def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS_ALL: diff --git a/tests/test_rand_gaussian_noise.py b/tests/test_rand_gaussian_noise.py index 7d4d04ff3f..233b4dd1b6 100644 --- a/tests/test_rand_gaussian_noise.py +++ b/tests/test_rand_gaussian_noise.py @@ -22,21 +22,24 @@ TESTS = [] for p in TEST_NDARRAYS: - TESTS.append(("test_zero_mean", p, 0, 0.1)) - TESTS.append(("test_non_zero_mean", p, 1, 0.5)) + TESTS.append(("test_zero_mean", p, 0, 0.1, True)) + TESTS.append(("test_non_zero_mean", p, 1, 0.5, True)) + TESTS.append(("test_no_sample_std", p, 1, 0.5, False)) class TestRandGaussianNoise(NumpyImageTestCase2D): + @parameterized.expand(TESTS) - def test_correct_results(self, _, im_type, mean, std): + def test_correct_results(self, _, im_type, mean, std, sample_std): seed = 0 - gaussian_fn = RandGaussianNoise(prob=1.0, mean=mean, std=std) + gaussian_fn = RandGaussianNoise(prob=1.0, mean=mean, std=std, sample_std=sample_std) gaussian_fn.set_random_state(seed) im = im_type(self.imt) noised = gaussian_fn(im) np.random.seed(seed) np.random.random() - expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape) + _std = np.random.uniform(0, std) if sample_std else std + expected = self.imt + np.random.normal(mean, _std, size=self.imt.shape) if isinstance(noised, torch.Tensor): noised = noised.cpu() np.testing.assert_allclose(expected, noised, atol=1e-5) diff --git a/tests/test_rand_gaussian_noised.py b/tests/test_rand_gaussian_noised.py index 24fc19f226..e3df196be2 100644 --- a/tests/test_rand_gaussian_noised.py +++ b/tests/test_rand_gaussian_noised.py @@ -22,23 +22,28 @@ TESTS = [] for p in TEST_NDARRAYS: - TESTS.append(["test_zero_mean", p, ["img1", "img2"], 0, 0.1]) - TESTS.append(["test_non_zero_mean", p, ["img1", "img2"], 1, 0.5]) + TESTS.append(["test_zero_mean", p, ["img1", "img2"], 0, 0.1, True]) + TESTS.append(["test_non_zero_mean", p, ["img1", "img2"], 1, 0.5, True]) + TESTS.append(["test_no_sample_std", p, ["img1", "img2"], 1, 0.5, False]) seed = 0 class TestRandGaussianNoised(NumpyImageTestCase2D): + @parameterized.expand(TESTS) - def test_correct_results(self, _, im_type, keys, mean, std): - gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64) + def test_correct_results(self, _, im_type, keys, mean, std, sample_std): + gaussian_fn = RandGaussianNoised( + keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64, sample_std=sample_std + ) gaussian_fn.set_random_state(seed) im = im_type(self.imt) noised = gaussian_fn({k: im for k in keys}) np.random.seed(seed) # simulate the randomize() of transform np.random.random() - noise = np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape) + _std = np.random.uniform(0, std) if sample_std else std + noise = np.random.normal(mean, _std, size=self.imt.shape) for k in keys: expected = self.imt + noise if isinstance(noised[k], torch.Tensor): diff --git a/tests/test_rand_gaussian_sharpen.py b/tests/test_rand_gaussian_sharpen.py index 8dff69cd4c..ee8604c14b 100644 --- a/tests/test_rand_gaussian_sharpen.py +++ b/tests/test_rand_gaussian_sharpen.py @@ -128,6 +128,7 @@ class TestRandGaussianSharpen(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, arguments, image, expected_data): converter = RandGaussianSharpen(**arguments) diff --git a/tests/test_rand_gaussian_sharpend.py b/tests/test_rand_gaussian_sharpend.py index 4c32880053..b9bae529db 100644 --- a/tests/test_rand_gaussian_sharpend.py +++ b/tests/test_rand_gaussian_sharpend.py @@ -131,6 +131,7 @@ class TestRandGaussianSharpend(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, arguments, image, expected_data): converter = RandGaussianSharpend(**arguments) diff --git a/tests/test_rand_gaussian_smooth.py b/tests/test_rand_gaussian_smooth.py index 9fb91a38a1..8bb36ca0fa 100644 --- a/tests/test_rand_gaussian_smooth.py +++ b/tests/test_rand_gaussian_smooth.py @@ -86,6 +86,7 @@ class TestRandGaussianSmooth(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, arguments, image, expected_data): converter = RandGaussianSmooth(**arguments) diff --git a/tests/test_rand_gaussian_smoothd.py b/tests/test_rand_gaussian_smoothd.py index d312494e46..a93b355184 100644 --- a/tests/test_rand_gaussian_smoothd.py +++ b/tests/test_rand_gaussian_smoothd.py @@ -86,6 +86,7 @@ class TestRandGaussianSmoothd(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, arguments, image, expected_data): converter = RandGaussianSmoothd(**arguments) diff --git a/tests/test_rand_gibbs_noise.py b/tests/test_rand_gibbs_noise.py index a0d18ae7f3..5ef249a1f4 100644 --- a/tests/test_rand_gibbs_noise.py +++ b/tests/test_rand_gibbs_noise.py @@ -32,6 +32,7 @@ class TestRandGibbsNoise(unittest.TestCase): + def setUp(self): set_determinism(0) super().setUp() @@ -89,6 +90,15 @@ def test_alpha(self, im_shape, input_type): self.assertGreaterEqual(t.sampled_alpha, 0.5) self.assertLessEqual(t.sampled_alpha, 0.51) + @parameterized.expand(TEST_CASES) + def test_alpha_single_value(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) + alpha = 0.01 + t = RandGibbsNoise(1.0, alpha) + _ = t(deepcopy(im)) + self.assertGreaterEqual(t.sampled_alpha, 0) + self.assertLessEqual(t.sampled_alpha, 0.01) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py index 4120f967e2..382290dd39 100644 --- a/tests/test_rand_gibbs_noised.py +++ b/tests/test_rand_gibbs_noised.py @@ -34,6 +34,7 @@ class TestRandGibbsNoised(unittest.TestCase): + def setUp(self): set_determinism(0) super().setUp() @@ -104,6 +105,14 @@ def test_alpha(self, im_shape, input_type): _ = t(deepcopy(data)) self.assertTrue(0.5 <= t.rand_gibbs_noise.sampled_alpha <= 0.51) + @parameterized.expand(TEST_CASES) + def test_alpha_single_value(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) + alpha = 0.01 + t = RandGibbsNoised(KEYS, 1.0, alpha) + _ = t(deepcopy(data)) + self.assertTrue(0 <= t.rand_gibbs_noise.sampled_alpha <= 0.01) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_grid_distortion.py b/tests/test_rand_grid_distortion.py index 8131a2382a..e07c311b25 100644 --- a/tests/test_rand_grid_distortion.py +++ b/tests/test_rand_grid_distortion.py @@ -84,6 +84,7 @@ class TestRandGridDistortion(unittest.TestCase): + @parameterized.expand(TESTS) def test_rand_grid_distortion(self, input_param, seed, input_data, expected_val): g = RandGridDistortion(**input_param) diff --git a/tests/test_rand_grid_distortiond.py b/tests/test_rand_grid_distortiond.py index 9f8ed3b9e6..f28e0ae86e 100644 --- a/tests/test_rand_grid_distortiond.py +++ b/tests/test_rand_grid_distortiond.py @@ -77,6 +77,7 @@ class TestRandGridDistortiond(unittest.TestCase): + @parameterized.expand(TESTS) def test_rand_grid_distortiond(self, input_param, seed, input_data, expected_val_img, expected_val_mask): g = RandGridDistortiond(**input_param) diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py index 494330584a..26863f01b2 100644 --- a/tests/test_rand_grid_patch.py +++ b/tests/test_rand_grid_patch.py @@ -105,6 +105,7 @@ class TestRandGridPatch(unittest.TestCase): + def setUp(self): set_determinism(seed=1234) diff --git a/tests/test_rand_grid_patchd.py b/tests/test_rand_grid_patchd.py index 23ca4a7881..031e834512 100644 --- a/tests/test_rand_grid_patchd.py +++ b/tests/test_rand_grid_patchd.py @@ -85,6 +85,7 @@ class TestRandGridPatchd(unittest.TestCase): + def setUp(self): set_determinism(seed=1234) diff --git a/tests/test_rand_histogram_shift.py b/tests/test_rand_histogram_shift.py index 318dad9dfa..785e24e53b 100644 --- a/tests/test_rand_histogram_shift.py +++ b/tests/test_rand_histogram_shift.py @@ -56,6 +56,7 @@ class TestRandHistogramShift(unittest.TestCase): + @parameterized.expand(TESTS) def test_rand_histogram_shift(self, input_param, input_data, expected_val): g = RandHistogramShift(**input_param) diff --git a/tests/test_rand_histogram_shiftd.py b/tests/test_rand_histogram_shiftd.py index 45e81ab012..fced270e90 100644 --- a/tests/test_rand_histogram_shiftd.py +++ b/tests/test_rand_histogram_shiftd.py @@ -61,6 +61,7 @@ class TestRandHistogramShiftD(unittest.TestCase): + @parameterized.expand(TESTS) def test_rand_histogram_shiftd(self, input_param, input_data, expected_val): g = RandHistogramShiftd(**input_param) diff --git a/tests/test_rand_k_space_spike_noise.py b/tests/test_rand_k_space_spike_noise.py index 4e7d59329b..7a9dd4288d 100644 --- a/tests/test_rand_k_space_spike_noise.py +++ b/tests/test_rand_k_space_spike_noise.py @@ -29,6 +29,7 @@ class TestRandKSpaceSpikeNoise(unittest.TestCase): + def setUp(self): set_determinism(0) super().setUp() diff --git a/tests/test_rand_k_space_spike_noised.py b/tests/test_rand_k_space_spike_noised.py index 3e1c11b2d9..86d4256637 100644 --- a/tests/test_rand_k_space_spike_noised.py +++ b/tests/test_rand_k_space_spike_noised.py @@ -30,6 +30,7 @@ class TestKSpaceSpikeNoised(unittest.TestCase): + def setUp(self): set_determinism(0) super().setUp() diff --git a/tests/test_rand_lambda.py b/tests/test_rand_lambda.py index 1f14499bc0..98a324aec5 100644 --- a/tests/test_rand_lambda.py +++ b/tests/test_rand_lambda.py @@ -37,6 +37,7 @@ def __call__(self, data): class TestRandLambda(unittest.TestCase): + def check(self, tr: RandLambda, img, img_orig_type, out, expected=None): # input shouldn't change self.assertIsInstance(img, img_orig_type) diff --git a/tests/test_rand_lambdad.py b/tests/test_rand_lambdad.py index 6b60a3fe70..5247d79843 100644 --- a/tests/test_rand_lambdad.py +++ b/tests/test_rand_lambdad.py @@ -37,6 +37,7 @@ def __call__(self, data): class TestRandLambdad(unittest.TestCase): + def check(self, tr: RandLambdad, input: dict, out: dict, expected: dict): if isinstance(input["img"], MetaTensor): self.assertEqual(len(input["img"].applied_operations), 0) diff --git a/tests/test_rand_rician_noise.py b/tests/test_rand_rician_noise.py index fe7135835e..8dd1c48e29 100644 --- a/tests/test_rand_rician_noise.py +++ b/tests/test_rand_rician_noise.py @@ -27,6 +27,7 @@ class TestRandRicianNoise(NumpyImageTestCase2D): + @parameterized.expand(TESTS) def test_correct_results(self, _, in_type, mean, std): seed = 0 diff --git a/tests/test_rand_rician_noised.py b/tests/test_rand_rician_noised.py index ae0acab4eb..a190ba866d 100644 --- a/tests/test_rand_rician_noised.py +++ b/tests/test_rand_rician_noised.py @@ -29,6 +29,7 @@ class TestRandRicianNoisedNumpy(NumpyImageTestCase2D): + @parameterized.expand(TESTS) def test_correct_results(self, _, in_type, keys, mean, std): rician_fn = RandRicianNoised(keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64) diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py index ca3eda3b12..c54229dcfe 100644 --- a/tests/test_rand_rotate.py +++ b/tests/test_rand_rotate.py @@ -73,6 +73,7 @@ class TestRandRotate2D(NumpyImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners): init_param = { @@ -112,6 +113,7 @@ def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, @unittest.skipIf(USE_COMPILED, "unit tests not for compiled version.") class TestRandRotate3D(NumpyImageTestCase3D): + @parameterized.expand(TEST_CASES_3D) def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected): init_param = { @@ -146,6 +148,7 @@ def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode, class TestRandRotateDtype(NumpyImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners): rotate_fn = RandRotate( diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py index 88f88bf422..be2e658b78 100644 --- a/tests/test_rand_rotate90.py +++ b/tests/test_rand_rotate90.py @@ -23,6 +23,7 @@ class TestRandRotate90(NumpyImageTestCase2D): + def test_default(self): rotate = RandRotate90() for p in TEST_NDARRAYS_ALL: diff --git a/tests/test_rand_rotate90d.py b/tests/test_rand_rotate90d.py index 23e9025c08..02836b5dd8 100644 --- a/tests/test_rand_rotate90d.py +++ b/tests/test_rand_rotate90d.py @@ -23,6 +23,7 @@ class TestRandRotate90d(NumpyImageTestCase2D): + def test_default(self): key = "test" rotate = RandRotate90d(keys=key) diff --git a/tests/test_rand_rotated.py b/tests/test_rand_rotated.py index a5a377b02f..71d0f67b63 100644 --- a/tests/test_rand_rotated.py +++ b/tests/test_rand_rotated.py @@ -109,6 +109,7 @@ class TestRandRotated2D(NumpyImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners): init_param = { @@ -153,6 +154,7 @@ def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, @unittest.skipIf(USE_COMPILED, "unit tests not for compiled version.") class TestRandRotated3D(NumpyImageTestCase3D): + @parameterized.expand(TEST_CASES_3D) def test_correct_shapes(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected): init_param = { diff --git a/tests/test_rand_scale_intensity.py b/tests/test_rand_scale_intensity.py index a857c0cefb..7e999c00b3 100644 --- a/tests/test_rand_scale_intensity.py +++ b/tests/test_rand_scale_intensity.py @@ -21,6 +21,7 @@ class TestRandScaleIntensity(NumpyImageTestCase2D): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_value(self, p): scaler = RandScaleIntensity(factors=0.5, prob=1.0) diff --git a/tests/test_rand_scale_intensity_fixed_mean.py b/tests/test_rand_scale_intensity_fixed_mean.py index f43adab32f..9324c711fa 100644 --- a/tests/test_rand_scale_intensity_fixed_mean.py +++ b/tests/test_rand_scale_intensity_fixed_mean.py @@ -21,6 +21,7 @@ class TestRandScaleIntensity(NumpyImageTestCase2D): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_value(self, p): scaler = RandScaleIntensityFixedMean(prob=1.0, factors=0.5) diff --git a/tests/test_rand_scale_intensity_fixed_meand.py b/tests/test_rand_scale_intensity_fixed_meand.py index c85c764a55..8c127ac130 100644 --- a/tests/test_rand_scale_intensity_fixed_meand.py +++ b/tests/test_rand_scale_intensity_fixed_meand.py @@ -20,6 +20,7 @@ class TestRandScaleIntensityFixedMeand(NumpyImageTestCase2D): + def test_value(self): key = "img" for p in TEST_NDARRAYS: diff --git a/tests/test_rand_scale_intensityd.py b/tests/test_rand_scale_intensityd.py index 8d928ac157..32c96f0313 100644 --- a/tests/test_rand_scale_intensityd.py +++ b/tests/test_rand_scale_intensityd.py @@ -20,6 +20,7 @@ class TestRandScaleIntensityd(NumpyImageTestCase2D): + def test_value(self): key = "img" for p in TEST_NDARRAYS: diff --git a/tests/test_rand_shift_intensity.py b/tests/test_rand_shift_intensity.py index 01ac55f7b8..907773ccf5 100644 --- a/tests/test_rand_shift_intensity.py +++ b/tests/test_rand_shift_intensity.py @@ -21,6 +21,7 @@ class TestRandShiftIntensity(NumpyImageTestCase2D): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_value(self, p): shifter = RandShiftIntensity(offsets=1.0, prob=1.0) diff --git a/tests/test_rand_shift_intensityd.py b/tests/test_rand_shift_intensityd.py index 7522676eb0..51675e324c 100644 --- a/tests/test_rand_shift_intensityd.py +++ b/tests/test_rand_shift_intensityd.py @@ -21,6 +21,7 @@ class TestRandShiftIntensityd(NumpyImageTestCase2D): + def test_value(self): key = "img" for p in TEST_NDARRAYS: diff --git a/tests/test_rand_simulate_low_resolution.py b/tests/test_rand_simulate_low_resolution.py index 7d05faad36..6aa586fb0b 100644 --- a/tests/test_rand_simulate_low_resolution.py +++ b/tests/test_rand_simulate_low_resolution.py @@ -71,6 +71,7 @@ class TestRandGaussianSmooth(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, arguments, image, expected_data): randsimlowres = RandSimulateLowResolution(**arguments) diff --git a/tests/test_rand_simulate_low_resolutiond.py b/tests/test_rand_simulate_low_resolutiond.py index f058ec3b2b..5ec84eba1d 100644 --- a/tests/test_rand_simulate_low_resolutiond.py +++ b/tests/test_rand_simulate_low_resolutiond.py @@ -60,6 +60,7 @@ class TestRandGaussianSmoothd(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, arguments, image, expected_data): converter = RandSimulateLowResolutiond(**arguments) diff --git a/tests/test_rand_spatial_crop_samplesd.py b/tests/test_rand_spatial_crop_samplesd.py index b37dacd643..cb53e94b7d 100644 --- a/tests/test_rand_spatial_crop_samplesd.py +++ b/tests/test_rand_spatial_crop_samplesd.py @@ -90,6 +90,7 @@ class TestRandSpatialCropSamplesd(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, *TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape, expected_last): xform = RandSpatialCropSamplesd(**input_param) diff --git a/tests/test_rand_std_shift_intensity.py b/tests/test_rand_std_shift_intensity.py index 535fb7cb20..0ac5e9482e 100644 --- a/tests/test_rand_std_shift_intensity.py +++ b/tests/test_rand_std_shift_intensity.py @@ -22,6 +22,7 @@ class TestRandStdShiftIntensity(NumpyImageTestCase2D): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_value(self, p): np.random.seed(0) diff --git a/tests/test_rand_std_shift_intensityd.py b/tests/test_rand_std_shift_intensityd.py index 31209ee754..1fd0c5d2a8 100644 --- a/tests/test_rand_std_shift_intensityd.py +++ b/tests/test_rand_std_shift_intensityd.py @@ -20,6 +20,7 @@ class TestRandStdShiftIntensityd(NumpyImageTestCase2D): + def test_value(self): for p in TEST_NDARRAYS: key = "img" diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py index 9d37779613..1524442f61 100644 --- a/tests/test_rand_weighted_cropd.py +++ b/tests/test_rand_weighted_cropd.py @@ -148,6 +148,7 @@ def get_data(ndim): class TestRandWeightedCrop(unittest.TestCase): + @parameterized.expand(TESTS) def test_rand_weighted_cropd(self, _, init_params, input_data, expected_shape, expected_centers): crop = RandWeightedCropd(**init_params) diff --git a/tests/test_rand_zoom.py b/tests/test_rand_zoom.py index d52b79d8cf..2da04fd652 100644 --- a/tests/test_rand_zoom.py +++ b/tests/test_rand_zoom.py @@ -33,6 +33,7 @@ class TestRandZoom(NumpyImageTestCase2D): + @parameterized.expand(VALID_CASES) def test_correct_results(self, min_zoom, max_zoom, mode, keep_size, align_corners=None): for p in TEST_NDARRAYS_ALL: diff --git a/tests/test_rand_zoomd.py b/tests/test_rand_zoomd.py index bb0495c793..bcbf188310 100644 --- a/tests/test_rand_zoomd.py +++ b/tests/test_rand_zoomd.py @@ -31,6 +31,7 @@ class TestRandZoomd(NumpyImageTestCase2D): + @parameterized.expand(VALID_CASES) def test_correct_results(self, min_zoom, max_zoom, mode, align_corners, keep_size): key = "img" diff --git a/tests/test_randidentity.py b/tests/test_randidentity.py index 09dc055b4e..3a8936f2d2 100644 --- a/tests/test_randidentity.py +++ b/tests/test_randidentity.py @@ -19,11 +19,13 @@ class T(mt.Transform): + def __call__(self, x): return x * 2 class TestIdentity(NumpyImageTestCase2D): + def test_identity(self): for p in TEST_NDARRAYS: img = p(self.imt) diff --git a/tests/test_random_order.py b/tests/test_random_order.py index e5507fafca..b38d2398fb 100644 --- a/tests/test_random_order.py +++ b/tests/test_random_order.py @@ -30,6 +30,7 @@ class InvC(Inv): + def __init__(self, keys): super().__init__(keys) self.fwd_fn = lambda x: x + 1 @@ -37,6 +38,7 @@ def __init__(self, keys): class InvD(Inv): + def __init__(self, keys): super().__init__(keys) self.fwd_fn = lambda x: x * 100 @@ -55,6 +57,7 @@ def __init__(self, keys): class TestRandomOrder(unittest.TestCase): + def test_empty_compose(self): c = RandomOrder() i = 1 @@ -113,6 +116,7 @@ def test_inverse(self, transform, invertible, use_metatensor): class TestRandomOrderAPITests(unittest.TestCase): + @staticmethod def data_from_keys(keys): if keys is None: diff --git a/tests/test_randomizable.py b/tests/test_randomizable.py index 96854a6db8..56d5293130 100644 --- a/tests/test_randomizable.py +++ b/tests/test_randomizable.py @@ -19,11 +19,13 @@ class RandTest(Randomizable): + def randomize(self, data=None): pass class TestRandomizable(unittest.TestCase): + def test_default(self): inst = RandTest() r1 = inst.R.rand() diff --git a/tests/test_randomizable_transform_type.py b/tests/test_randomizable_transform_type.py index 3a0995be68..919f9299bf 100644 --- a/tests/test_randomizable_transform_type.py +++ b/tests/test_randomizable_transform_type.py @@ -21,11 +21,13 @@ class InheritsInterface(RandomizableTrait): class InheritsImplementation(RandomizableTransform): + def __call__(self, data): return data class TestRandomizableTransformType(unittest.TestCase): + def test_is_randomizable_transform_type(self): inst = InheritsInterface() self.assertIsInstance(inst, RandomizableTrait) diff --git a/tests/test_randtorchvisiond.py b/tests/test_randtorchvisiond.py index 82f9adf473..7ad06dfd2a 100644 --- a/tests/test_randtorchvisiond.py +++ b/tests/test_randtorchvisiond.py @@ -52,6 +52,7 @@ class TestRandTorchVisiond(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_value(self, input_param, input_data, expected_value): set_determinism(seed=0) diff --git a/tests/test_rankfilter_dist.py b/tests/test_rankfilter_dist.py index 40cd36f31d..fd02e3bdc9 100644 --- a/tests/test_rankfilter_dist.py +++ b/tests/test_rankfilter_dist.py @@ -23,6 +23,7 @@ class DistributedRankFilterTest(DistTestCase): + def setUp(self): self.log_dir = tempfile.TemporaryDirectory() @@ -50,6 +51,7 @@ def tearDown(self) -> None: class SingleRankFilterTest(unittest.TestCase): + def tearDown(self) -> None: self.log_dir.cleanup() diff --git a/tests/test_recon_net_utils.py b/tests/test_recon_net_utils.py index 38adb9617b..1815000777 100644 --- a/tests/test_recon_net_utils.py +++ b/tests/test_recon_net_utils.py @@ -49,6 +49,7 @@ class TestReconNetUtils(unittest.TestCase): + @parameterized.expand(TEST_RESHAPE) def test_reshape_channel_complex(self, test_data): result = reshape_complex_to_channel_dim(test_data) diff --git a/tests/test_reference_based_normalize_intensity.py b/tests/test_reference_based_normalize_intensity.py index 8d2715f983..2d946af118 100644 --- a/tests/test_reference_based_normalize_intensity.py +++ b/tests/test_reference_based_normalize_intensity.py @@ -52,6 +52,7 @@ class TestDetailedNormalizeIntensityd(unittest.TestCase): + @parameterized.expand(TESTS) def test_target_mean_std(self, args, data, normalized_data, normalized_target, mean, std): dtype = data[args["keys"][0]].dtype diff --git a/tests/test_reference_based_spatial_cropd.py b/tests/test_reference_based_spatial_cropd.py index d5777482c0..83cd9c4a5d 100644 --- a/tests/test_reference_based_spatial_cropd.py +++ b/tests/test_reference_based_spatial_cropd.py @@ -46,6 +46,7 @@ class TestTargetBasedSpatialCropd(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, args, data, expected_shape): cropper = ReferenceBasedSpatialCropd(keys=args["keys"], ref_key=args["ref_key"]) diff --git a/tests/test_reference_resolver.py b/tests/test_reference_resolver.py index 07d56a16df..1f02bb01a7 100644 --- a/tests/test_reference_resolver.py +++ b/tests/test_reference_resolver.py @@ -70,6 +70,7 @@ class TestReferenceResolver(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2] + ([TEST_CASE_3] if has_tv else [])) def test_resolve(self, configs, expected_id, output_type): locator = ComponentLocator() diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index 6cd973c32e..e8f82eb0c2 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -32,6 +32,7 @@ class TestRegLossIntegration(unittest.TestCase): + def setUp(self): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False @@ -61,6 +62,7 @@ def test_convergence(self, loss_type, loss_args, forward_args, pred_channels=1): # define a one layer model class OnelayerNet(nn.Module): + def __init__(self): super().__init__() self.layer = nn.Sequential( diff --git a/tests/test_regularization.py b/tests/test_regularization.py new file mode 100644 index 0000000000..4df60b9808 --- /dev/null +++ b/tests/test_regularization.py @@ -0,0 +1,112 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd +from monai.utils import set_determinism + + +class TestMixup(unittest.TestCase): + + def setUp(self) -> None: + set_determinism(seed=0) + + def tearDown(self) -> None: + set_determinism(None) + + def test_mixup(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + sample = torch.rand(*shape, dtype=torch.float32) + mixup = MixUp(6, 1.0) + output = mixup(sample) + self.assertEqual(output.shape, sample.shape) + self.assertTrue(any(not torch.allclose(sample, mixup(sample)) for _ in range(10))) + + with self.assertRaises(ValueError): + MixUp(6, -0.5) + + mixup = MixUp(6, 0.5) + for dims in [2, 3]: + with self.assertRaises(ValueError): + shape = (5, 3) + (32,) * dims + sample = torch.rand(*shape, dtype=torch.float32) + mixup(sample) + + def test_mixupd(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + t = torch.rand(*shape, dtype=torch.float32) + sample = {"a": t, "b": t} + mixup = MixUpd(["a", "b"], 6) + output = mixup(sample) + self.assertTrue(torch.allclose(output["a"], output["b"])) + + with self.assertRaises(ValueError): + MixUpd(["k1", "k2"], 6, -0.5) + + +class TestCutMix(unittest.TestCase): + + def setUp(self) -> None: + set_determinism(seed=0) + + def tearDown(self) -> None: + set_determinism(None) + + def test_cutmix(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + sample = torch.rand(*shape, dtype=torch.float32) + cutmix = CutMix(6, 1.0) + output = cutmix(sample) + self.assertEqual(output.shape, sample.shape) + self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10))) + + def test_cutmixd(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + t = torch.rand(*shape, dtype=torch.float32) + label = torch.randint(0, 1, shape) + sample = {"a": t, "b": t, "lbl1": label, "lbl2": label} + cutmix = CutMixd(["a", "b"], 6, label_keys=("lbl1", "lbl2")) + output = cutmix(sample) + # croppings are different on each application + self.assertTrue(not torch.allclose(output["a"], output["b"])) + # but mixing of labels is not affected by it + self.assertTrue(torch.allclose(output["lbl1"], output["lbl2"])) + + +class TestCutOut(unittest.TestCase): + + def setUp(self) -> None: + set_determinism(seed=0) + + def tearDown(self) -> None: + set_determinism(None) + + def test_cutout(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + sample = torch.rand(*shape, dtype=torch.float32) + cutout = CutOut(6, 1.0) + output = cutout(sample) + self.assertEqual(output.shape, sample.shape) + self.assertTrue(any(not torch.allclose(sample, cutout(sample)) for _ in range(10))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_regunet.py b/tests/test_regunet.py index 04ff60ef30..3100d7660c 100644 --- a/tests/test_regunet.py +++ b/tests/test_regunet.py @@ -63,6 +63,7 @@ class TestREGUNET(unittest.TestCase): + @parameterized.expand(TEST_CASE_REGUNET_2D + TEST_CASE_REGUNET_3D) def test_shape(self, input_param, input_shape, expected_shape): net = RegUNet(**input_param).to(device) diff --git a/tests/test_regunet_block.py b/tests/test_regunet_block.py index eebe9d8694..fa07671d03 100644 --- a/tests/test_regunet_block.py +++ b/tests/test_regunet_block.py @@ -65,6 +65,7 @@ class TestRegistrationResidualConvBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_RESIDUAL) def test_shape(self, input_param, input_shape, expected_shape): net = RegistrationResidualConvBlock(**input_param) @@ -74,6 +75,7 @@ def test_shape(self, input_param, input_shape, expected_shape): class TestRegistrationDownSampleBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_DOWN_SAMPLE) def test_shape(self, input_param, input_shape, expected_shape): net = RegistrationDownSampleBlock(**input_param) @@ -88,6 +90,7 @@ def test_ill_shape(self): class TestRegistrationExtractionBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_EXTRACTION) def test_shape(self, input_param, input_shapes, image_size, expected_shape): net = RegistrationExtractionBlock(**input_param) diff --git a/tests/test_remove_repeated_channel.py b/tests/test_remove_repeated_channel.py index 90b1b79b03..7da00ee75d 100644 --- a/tests/test_remove_repeated_channel.py +++ b/tests/test_remove_repeated_channel.py @@ -24,6 +24,7 @@ class TestRemoveRepeatedChannel(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_shape): result = RemoveRepeatedChannel(**input_param)(input_data) diff --git a/tests/test_remove_repeated_channeld.py b/tests/test_remove_repeated_channeld.py index 6d36d32f6f..08ec7fb44c 100644 --- a/tests/test_remove_repeated_channeld.py +++ b/tests/test_remove_repeated_channeld.py @@ -34,6 +34,7 @@ class TestRemoveRepeatedChanneld(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = RemoveRepeatedChanneld(**input_param)(input_data) diff --git a/tests/test_remove_small_objects.py b/tests/test_remove_small_objects.py index 200f4ed9b2..633a6d9a99 100644 --- a/tests/test_remove_small_objects.py +++ b/tests/test_remove_small_objects.py @@ -55,6 +55,7 @@ @SkipIfNoModule("skimage.morphology") class TestRemoveSmallObjects(unittest.TestCase): + @parameterized.expand(TESTS) def test_remove_small_objects(self, dtype, im_type, lbl, expected, params=None): params = params or {} diff --git a/tests/test_repeat_channel.py b/tests/test_repeat_channel.py index 0ae5743836..82d1d92bd2 100644 --- a/tests/test_repeat_channel.py +++ b/tests/test_repeat_channel.py @@ -24,6 +24,7 @@ class TestRepeatChannel(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = RepeatChannel(**input_param)(input_data) diff --git a/tests/test_repeat_channeld.py b/tests/test_repeat_channeld.py index 9f7872135d..2be13a08d1 100644 --- a/tests/test_repeat_channeld.py +++ b/tests/test_repeat_channeld.py @@ -31,6 +31,7 @@ class TestRepeatChanneld(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = RepeatChanneld(**input_param)(input_data) diff --git a/tests/test_replace_module.py b/tests/test_replace_module.py index cac3fd39e5..f3964ac65d 100644 --- a/tests/test_replace_module.py +++ b/tests/test_replace_module.py @@ -32,6 +32,7 @@ class TestReplaceModule(unittest.TestCase): + def setUp(self): self.net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) self.num_relus = self.get_num_modules(torch.nn.ReLU) diff --git a/tests/test_require_pkg.py b/tests/test_require_pkg.py index b1a3d82a17..065a7509a4 100644 --- a/tests/test_require_pkg.py +++ b/tests/test_require_pkg.py @@ -17,7 +17,9 @@ class TestRequirePkg(unittest.TestCase): + def test_class(self): + @require_pkg(pkg_name="torch", version="1.4", version_checker=min_version) class TestClass: pass @@ -25,6 +27,7 @@ class TestClass: TestClass() def test_function(self): + @require_pkg(pkg_name="torch", version="1.4", version_checker=min_version) def test_func(x): return x @@ -32,6 +35,7 @@ def test_func(x): test_func(x=None) def test_warning(self): + @require_pkg(pkg_name="test123", raise_error=False) def test_func(x): return x diff --git a/tests/test_resample.py b/tests/test_resample.py index c90dc5f13d..68b08b8b87 100644 --- a/tests/test_resample.py +++ b/tests/test_resample.py @@ -35,6 +35,7 @@ def rotate_90_2d(): class TestResampleFunction(unittest.TestCase): + @parameterized.expand(RESAMPLE_FUNCTION_CASES) def test_resample_function_impl(self, img, matrix, expected): out = resample(convert_to_tensor(img), matrix, {"lazy_shape": img.shape[1:], "lazy_padding_mode": "border"}) diff --git a/tests/test_resample_backends.py b/tests/test_resample_backends.py index 97ee0731e8..7ddd9c7ec2 100644 --- a/tests/test_resample_backends.py +++ b/tests/test_resample_backends.py @@ -44,6 +44,7 @@ @SkipIfBeforePyTorchVersion((1, 9, 1)) class TestResampleBackends(unittest.TestCase): + @parameterized.expand(TEST_IDENTITY) def test_resample_identity(self, input_param, im_type, interp, padding, input_shape): """test resampling of an identity grid with padding 2, im_type, interp, padding, input_shape""" diff --git a/tests/test_resample_datalist.py b/tests/test_resample_datalist.py index ae52492953..ac5cb25bb3 100644 --- a/tests/test_resample_datalist.py +++ b/tests/test_resample_datalist.py @@ -32,6 +32,7 @@ class TestResampleDatalist(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_value_shape(self, input_param, expected): result = resample_datalist(**input_param) diff --git a/tests/test_resample_to_match.py b/tests/test_resample_to_match.py index b12ffd04be..f0d34547a7 100644 --- a/tests/test_resample_to_match.py +++ b/tests/test_resample_to_match.py @@ -46,6 +46,7 @@ def get_rand_fname(len=10, suffix=".nii.gz"): @unittest.skipUnless(has_itk, "itk not installed") class TestResampleToMatch(unittest.TestCase): + @classmethod def setUpClass(cls): super(__class__, cls).setUpClass() diff --git a/tests/test_resample_to_matchd.py b/tests/test_resample_to_matchd.py index 748e830bdd..9d104bf392 100644 --- a/tests/test_resample_to_matchd.py +++ b/tests/test_resample_to_matchd.py @@ -36,6 +36,7 @@ def update_fname(d): class TestResampleToMatchd(unittest.TestCase): + @classmethod def setUpClass(cls): super(__class__, cls).setUpClass() diff --git a/tests/test_resampler.py b/tests/test_resampler.py index 50ea344090..af0db657aa 100644 --- a/tests/test_resampler.py +++ b/tests/test_resampler.py @@ -152,6 +152,7 @@ class TestResample(unittest.TestCase): + @parameterized.expand(TESTS) def test_resample(self, input_param, input_data, expected_val): g = Resample(**input_param) diff --git a/tests/test_resize.py b/tests/test_resize.py index 97a8f8dab2..d4c57e2742 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -21,7 +21,14 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import Resize from tests.lazy_transforms_utils import test_resampler_lazy -from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, is_tf32_env, pytorch_after +from tests.utils import ( + TEST_NDARRAYS_ALL, + NumpyImageTestCase2D, + SkipIfAtLeastPyTorchVersion, + assert_allclose, + is_tf32_env, + pytorch_after, +) TEST_CASE_0 = [{"spatial_size": 15}, (6, 10, 15)] @@ -39,6 +46,7 @@ class TestResize(NumpyImageTestCase2D): + def test_invalid_inputs(self): with self.assertRaises(ValueError): resize = Resize(spatial_size=(128, 128, 3), mode="order") @@ -111,6 +119,7 @@ def test_correct_results(self, spatial_size, mode, anti_aliasing): ) @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_2_1, TEST_CASE_3, TEST_CASE_4]) + @SkipIfAtLeastPyTorchVersion((2, 2, 0)) # https://github.com/Project-MONAI/MONAI/issues/7445 def test_longest_shape(self, input_param, expected_shape): input_data = np.random.randint(0, 2, size=[3, 4, 7, 10]) input_param["size_mode"] = "longest" diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py index 287df039b8..daf257f89f 100644 --- a/tests/test_resize_with_pad_or_crop.py +++ b/tests/test_resize_with_pad_or_crop.py @@ -48,6 +48,7 @@ class TestResizeWithPadOrCrop(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_pad_shape(self, input_param, input_shape, expected_shape, _): for p in TEST_NDARRAYS_ALL: diff --git a/tests/test_resize_with_pad_or_cropd.py b/tests/test_resize_with_pad_or_cropd.py index 471144a609..391e0feb22 100644 --- a/tests/test_resize_with_pad_or_cropd.py +++ b/tests/test_resize_with_pad_or_cropd.py @@ -46,6 +46,7 @@ class TestResizeWithPadOrCropd(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_pad_shape(self, input_param, input_data, expected_val): for p in TEST_NDARRAYS_ALL: diff --git a/tests/test_resized.py b/tests/test_resized.py index bd711b33d8..243a4e6622 100644 --- a/tests/test_resized.py +++ b/tests/test_resized.py @@ -21,7 +21,13 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import Invertd, Resize, Resized from tests.lazy_transforms_utils import test_resampler_lazy -from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion +from tests.utils import ( + TEST_NDARRAYS_ALL, + NumpyImageTestCase2D, + SkipIfAtLeastPyTorchVersion, + assert_allclose, + test_local_inversion, +) TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 10, 15)] @@ -58,7 +64,9 @@ ] +@SkipIfAtLeastPyTorchVersion((2, 2, 0)) # https://github.com/Project-MONAI/MONAI/issues/7445 class TestResized(NumpyImageTestCase2D): + def test_invalid_inputs(self): with self.assertRaises(ValueError): resize = Resized(keys="img", spatial_size=(128, 128, 3), mode="order") diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 15ec6353f9..449edba4bf 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -24,6 +24,7 @@ from monai.networks import eval_mode from monai.networks.nets import ( ResNet, + ResNetFeatures, get_medicalnet_pretrained_resnet_args, get_pretrained_resnet_medicalnet, resnet10, @@ -36,7 +37,14 @@ ) from monai.networks.nets.resnet import ResNetBlock from monai.utils import optional_import -from tests.utils import equal_state_dict, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, test_script_save +from tests.utils import ( + SkipIfNoModule, + equal_state_dict, + skip_if_downloading_fails, + skip_if_no_cuda, + skip_if_quick, + test_script_save, +) if TYPE_CHECKING: import torchvision @@ -191,7 +199,17 @@ ] +CASE_EXTRACT_FEATURES = [ + ( + {"model_name": "resnet10", "pretrained": True, "spatial_dims": 3, "in_channels": 1}, + [1, 1, 64, 64, 64], + ([1, 64, 32, 32, 32], [1, 64, 16, 16, 16], [1, 128, 8, 8, 8], [1, 256, 4, 4, 4], [1, 512, 2, 2, 2]), + ) +] + + class TestResNet(unittest.TestCase): + def setUp(self): self.tmp_ckpt_filename = os.path.join("tests", "monai_unittest_tmp_ckpt.pth") @@ -269,5 +287,25 @@ def test_script(self, model, input_param, input_shape, expected_shape): test_script_save(net, test_data) +@SkipIfNoModule("hf_hub_download") +class TestExtractFeatures(unittest.TestCase): + + @parameterized.expand(CASE_EXTRACT_FEATURES) + def test_shape(self, input_param, input_shape, expected_shapes): + device = "cuda" if torch.cuda.is_available() else "cpu" + + with skip_if_downloading_fails(): + net = ResNetFeatures(**input_param).to(device) + + # run inference with random tensor + with eval_mode(net): + features = net(torch.randn(input_shape).to(device)) + + # check output shape + self.assertEqual(len(features), len(expected_shapes)) + for feature, expected_shape in zip(features, expected_shapes): + self.assertEqual(feature.shape, torch.Size(expected_shape)) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_retinanet.py b/tests/test_retinanet.py index 074a5b63fa..f36708d5b3 100644 --- a/tests/test_retinanet.py +++ b/tests/test_retinanet.py @@ -101,6 +101,7 @@ @unittest.skipUnless(has_torchvision, "Requires torchvision") @skip_if_quick class TestRetinaNet(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_retina_shape(self, model, input_param, input_shape): backbone = model(**input_param) diff --git a/tests/test_retinanet_detector.py b/tests/test_retinanet_detector.py index 7292bc0c49..691254fd87 100644 --- a/tests/test_retinanet_detector.py +++ b/tests/test_retinanet_detector.py @@ -93,6 +93,7 @@ class NaiveNetwork(torch.nn.Module): + def __init__(self, spatial_dims, num_classes, **kwargs): super().__init__() self.spatial_dims = spatial_dims @@ -114,6 +115,7 @@ def forward(self, images): @unittest.skipUnless(has_torchvision, "Requires torchvision") @skip_if_quick class TestRetinaNetDetector(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_retina_detector_resnet_backbone_shape(self, input_param, input_shape): returned_layers = [1] diff --git a/tests/test_retinanet_predict_utils.py b/tests/test_retinanet_predict_utils.py index d97806e91c..d909699469 100644 --- a/tests/test_retinanet_predict_utils.py +++ b/tests/test_retinanet_predict_utils.py @@ -85,6 +85,7 @@ class NaiveNetwork(torch.nn.Module): + def __init__(self, spatial_dims, num_classes, **kwargs): super().__init__() self.spatial_dims = spatial_dims @@ -103,6 +104,7 @@ def forward(self, images): class NaiveNetwork2(torch.nn.Module): + def __init__(self, spatial_dims, num_classes, **kwargs): super().__init__() self.spatial_dims = spatial_dims @@ -121,6 +123,7 @@ def forward(self, images): class TestPredictor(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_naive_predictor(self, input_param, input_shape): net = NaiveNetwork(**input_param) diff --git a/tests/test_rotate.py b/tests/test_rotate.py index 95c63e65f7..19fbd1409f 100644 --- a/tests/test_rotate.py +++ b/tests/test_rotate.py @@ -52,6 +52,7 @@ class TestRotate2D(NumpyImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): init_param = { @@ -90,6 +91,7 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al class TestRotate3D(NumpyImageTestCase3D): + @parameterized.expand(TEST_CASES_3D) def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): init_param = { diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py index 0948469df9..ebc3fba7e0 100644 --- a/tests/test_rotate90.py +++ b/tests/test_rotate90.py @@ -31,6 +31,7 @@ class TestRotate90(NumpyImageTestCase2D): + def test_rotate90_default(self): rotate = Rotate90() for p in TEST_NDARRAYS_ALL: @@ -102,6 +103,7 @@ def test_prob_k_spatial_axes(self): class TestRotate903d(NumpyImageTestCase3D): + def test_rotate90_default(self): rotate = Rotate90() for p in TEST_NDARRAYS_ALL: @@ -169,6 +171,7 @@ def test_prob_k_spatial_axes(self): @unittest.skipUnless(optional_import("scipy")[1], "Requires scipy library.") class TestRot90Consistency(unittest.TestCase): + @parameterized.expand([[2], [3], [4]]) def test_affine_rot90(self, s): """s""" diff --git a/tests/test_rotate90d.py b/tests/test_rotate90d.py index 08d3a97498..ffe920992a 100644 --- a/tests/test_rotate90d.py +++ b/tests/test_rotate90d.py @@ -22,6 +22,7 @@ class TestRotate90d(NumpyImageTestCase2D): + def test_rotate90_default(self): key = "test" rotate = Rotate90d(keys=key) diff --git a/tests/test_rotated.py b/tests/test_rotated.py index 3755ab1344..28ca755661 100644 --- a/tests/test_rotated.py +++ b/tests/test_rotated.py @@ -43,6 +43,7 @@ @unittest.skipIf(USE_COMPILED, "unittests are not designed for both USE_COMPILED=True/False") class TestRotated2D(NumpyImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): init_param = { @@ -94,6 +95,7 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al @unittest.skipIf(USE_COMPILED, "unittests are not designed for both USE_COMPILED=True/False") class TestRotated3D(NumpyImageTestCase3D): + @parameterized.expand(TEST_CASES_3D) def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): init_param = { @@ -143,6 +145,7 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al @unittest.skipIf(USE_COMPILED, "unittests are not designed for both USE_COMPILED=True/False") class TestRotated3DXY(NumpyImageTestCase3D): + @parameterized.expand(TEST_CASES_3D) def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotated( diff --git a/tests/test_safe_dtype_range.py b/tests/test_safe_dtype_range.py index 73f9607d7d..61b55635ae 100644 --- a/tests/test_safe_dtype_range.py +++ b/tests/test_safe_dtype_range.py @@ -54,6 +54,7 @@ class TesSafeDtypeRange(unittest.TestCase): + @parameterized.expand(TESTS) def test_safe_dtype_range(self, in_image, im_out, out_dtype): result = safe_dtype_range(in_image, out_dtype) diff --git a/tests/test_saliency_inferer.py b/tests/test_saliency_inferer.py index 4efe30d7a6..70ec048d1c 100644 --- a/tests/test_saliency_inferer.py +++ b/tests/test_saliency_inferer.py @@ -28,6 +28,7 @@ class TestSaliencyInferer(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, cam_name): model = DenseNet( diff --git a/tests/test_sample_slices.py b/tests/test_sample_slices.py index 02b7926392..a183689970 100644 --- a/tests/test_sample_slices.py +++ b/tests/test_sample_slices.py @@ -32,6 +32,7 @@ class TestSampleSlices(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_shape(self, input_data, dim, as_indices, vals, expected_result): for p in TEST_NDARRAYS: diff --git a/tests/test_sampler_dist.py b/tests/test_sampler_dist.py index b2f86c54cc..b8bd1c7a9f 100644 --- a/tests/test_sampler_dist.py +++ b/tests/test_sampler_dist.py @@ -24,6 +24,7 @@ class DistributedSamplerTest(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) def test_even(self): data = [1, 2, 3, 4, 5] diff --git a/tests/test_save_classificationd.py b/tests/test_save_classificationd.py index dd0b213bd6..9a7d4fc3f5 100644 --- a/tests/test_save_classificationd.py +++ b/tests/test_save_classificationd.py @@ -26,6 +26,7 @@ class TestSaveClassificationd(unittest.TestCase): + def test_saved_content(self): with tempfile.TemporaryDirectory() as tempdir: data = [ diff --git a/tests/test_save_image.py b/tests/test_save_image.py index d88db201ce..ed7061095d 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -42,6 +42,7 @@ @unittest.skipUnless(has_itk, "itk not installed") class TestSaveImage(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_saved_content(self, test_data, meta_data, output_ext, resample): if meta_data is not None: diff --git a/tests/test_save_imaged.py b/tests/test_save_imaged.py index ab0b9c0d9f..d2095a7554 100644 --- a/tests/test_save_imaged.py +++ b/tests/test_save_imaged.py @@ -54,6 +54,7 @@ @unittest.skipUnless(has_itk, "itk not installed") class TestSaveImaged(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_saved_content(self, test_data, output_ext, resample): with tempfile.TemporaryDirectory() as tempdir: @@ -73,7 +74,9 @@ def test_saved_content(self, test_data, output_ext, resample): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_custom_folderlayout(self, test_data, output_ext, resample): + class TestFolderLayout(FolderLayoutBase): + def __init__(self, basepath: Path, extension: str, makedirs: bool): self.basepath = basepath self.ext = extension diff --git a/tests/test_save_state.py b/tests/test_save_state.py index 8ab7080700..0581a3ce1f 100644 --- a/tests/test_save_state.py +++ b/tests/test_save_state.py @@ -43,6 +43,7 @@ class TestSaveState(unittest.TestCase): + @parameterized.expand( [ TEST_CASE_1, diff --git a/tests/test_savitzky_golay_filter.py b/tests/test_savitzky_golay_filter.py index b7f89cdfde..7c60287e2d 100644 --- a/tests/test_savitzky_golay_filter.py +++ b/tests/test_savitzky_golay_filter.py @@ -100,6 +100,7 @@ class TestSavitzkyGolayCPU(unittest.TestCase): + @parameterized.expand( [TEST_CASE_SINGLE_VALUE, TEST_CASE_1D, TEST_CASE_2D_AXIS_2, TEST_CASE_2D_AXIS_3, TEST_CASE_SINE_SMOOTH] ) @@ -109,6 +110,7 @@ def test_value(self, arguments, image, expected_data, atol, rtol=1e-5): class TestSavitzkyGolayCPUREP(unittest.TestCase): + @parameterized.expand( [TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP] ) @@ -119,6 +121,7 @@ def test_value(self, arguments, image, expected_data, atol, rtol=1e-5): @skip_if_no_cuda class TestSavitzkyGolayGPU(unittest.TestCase): + @parameterized.expand( [TEST_CASE_SINGLE_VALUE, TEST_CASE_1D, TEST_CASE_2D_AXIS_2, TEST_CASE_2D_AXIS_3, TEST_CASE_SINE_SMOOTH] ) @@ -129,6 +132,7 @@ def test_value(self, arguments, image, expected_data, atol, rtol=1e-5): @skip_if_no_cuda class TestSavitzkyGolayGPUREP(unittest.TestCase): + @parameterized.expand( [TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP] ) diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py index 6da4f24c62..14e403e238 100644 --- a/tests/test_savitzky_golay_smooth.py +++ b/tests/test_savitzky_golay_smooth.py @@ -60,6 +60,7 @@ class TestSavitzkyGolaySmooth(unittest.TestCase): + @parameterized.expand( [TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH, TEST_CASE_SINGLE_VALUE_REP] ) diff --git a/tests/test_savitzky_golay_smoothd.py b/tests/test_savitzky_golay_smoothd.py index 7e7176e2bb..3bb4056046 100644 --- a/tests/test_savitzky_golay_smoothd.py +++ b/tests/test_savitzky_golay_smoothd.py @@ -60,6 +60,7 @@ class TestSavitzkyGolaySmoothd(unittest.TestCase): + @parameterized.expand( [TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH, TEST_CASE_SINGLE_VALUE_REP] ) diff --git a/tests/test_scale_intensity.py b/tests/test_scale_intensity.py index 57a7da1780..17dfe305b2 100644 --- a/tests/test_scale_intensity.py +++ b/tests/test_scale_intensity.py @@ -22,6 +22,7 @@ class TestScaleIntensity(NumpyImageTestCase2D): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_range_scale(self, p): scaler = ScaleIntensity(minv=1.0, maxv=2.0) diff --git a/tests/test_scale_intensity_fixed_mean.py b/tests/test_scale_intensity_fixed_mean.py index afbcd46141..35d38ef0b1 100644 --- a/tests/test_scale_intensity_fixed_mean.py +++ b/tests/test_scale_intensity_fixed_mean.py @@ -21,6 +21,7 @@ class TestScaleIntensityFixedMean(NumpyImageTestCase2D): + def test_factor_scale(self): for p in TEST_NDARRAYS: scaler = ScaleIntensityFixedMean(factor=0.1, fixed_mean=False) diff --git a/tests/test_scale_intensity_range.py b/tests/test_scale_intensity_range.py index 898f4dfb45..6013a237db 100644 --- a/tests/test_scale_intensity_range.py +++ b/tests/test_scale_intensity_range.py @@ -20,6 +20,7 @@ class IntensityScaleIntensityRange(NumpyImageTestCase2D): + def test_image_scale_intensity_range(self): scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=50, b_max=80, dtype=np.uint8) for p in TEST_NDARRAYS: diff --git a/tests/test_scale_intensity_range_percentiles.py b/tests/test_scale_intensity_range_percentiles.py index 583dcec07e..7c3a684a00 100644 --- a/tests/test_scale_intensity_range_percentiles.py +++ b/tests/test_scale_intensity_range_percentiles.py @@ -20,6 +20,7 @@ class TestScaleIntensityRangePercentiles(NumpyImageTestCase2D): + def test_scaling(self): img = self.imt[0] lower = 10 diff --git a/tests/test_scale_intensity_range_percentilesd.py b/tests/test_scale_intensity_range_percentilesd.py index 8e2511d9e4..ab0347fbbf 100644 --- a/tests/test_scale_intensity_range_percentilesd.py +++ b/tests/test_scale_intensity_range_percentilesd.py @@ -20,6 +20,7 @@ class TestScaleIntensityRangePercentilesd(NumpyImageTestCase2D): + def test_scaling(self): img = self.imt lower = 10 diff --git a/tests/test_scale_intensity_ranged.py b/tests/test_scale_intensity_ranged.py index 724acf1c73..cc3f1220e7 100644 --- a/tests/test_scale_intensity_ranged.py +++ b/tests/test_scale_intensity_ranged.py @@ -18,6 +18,7 @@ class IntensityScaleIntensityRanged(NumpyImageTestCase2D): + def test_image_scale_intensity_ranged(self): key = "img" scaler = ScaleIntensityRanged(keys=key, a_min=20, a_max=108, b_min=50, b_max=80) diff --git a/tests/test_scale_intensityd.py b/tests/test_scale_intensityd.py index 6705cfda9d..88beece894 100644 --- a/tests/test_scale_intensityd.py +++ b/tests/test_scale_intensityd.py @@ -20,6 +20,7 @@ class TestScaleIntensityd(NumpyImageTestCase2D): + def test_range_scale(self): key = "img" for p in TEST_NDARRAYS: diff --git a/tests/test_se_block.py b/tests/test_se_block.py index de129f4d55..ca60643635 100644 --- a/tests/test_se_block.py +++ b/tests/test_se_block.py @@ -63,6 +63,7 @@ class TestSEBlockLayer(unittest.TestCase): + @parameterized.expand(TEST_CASES + TEST_CASES_3D) def test_shape(self, input_param, input_shape, expected_shape): net = SEBlock(**input_param).to(device) diff --git a/tests/test_se_blocks.py b/tests/test_se_blocks.py index c97e459f50..c1e72749cc 100644 --- a/tests/test_se_blocks.py +++ b/tests/test_se_blocks.py @@ -41,6 +41,7 @@ class TestChannelSELayer(unittest.TestCase): + @parameterized.expand(TEST_CASES + TEST_CASES_3D) def test_shape(self, input_param, input_shape, expected_shape): net = ChannelSELayer(**input_param) @@ -60,6 +61,7 @@ def test_ill_arg(self): class TestResidualSELayer(unittest.TestCase): + @parameterized.expand(TEST_CASES[:1]) def test_shape(self, input_param, input_shape, expected_shape): net = ResidualSELayer(**input_param) diff --git a/tests/test_seg_loss_integration.py b/tests/test_seg_loss_integration.py index 23bc63fbf6..6713e7bba9 100644 --- a/tests/test_seg_loss_integration.py +++ b/tests/test_seg_loss_integration.py @@ -47,6 +47,7 @@ class TestSegLossIntegration(unittest.TestCase): + def setUp(self): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False @@ -92,6 +93,7 @@ def test_convergence(self, loss_type, loss_args, forward_args): # define a one layer model class OnelayerNet(nn.Module): + def __init__(self): super().__init__() self.layer_1 = nn.Linear(num_voxels, 200) diff --git a/tests/test_segresnet.py b/tests/test_segresnet.py index cb34445efa..728699c434 100644 --- a/tests/test_segresnet.py +++ b/tests/test_segresnet.py @@ -83,6 +83,7 @@ class TestResNet(unittest.TestCase): + @parameterized.expand(TEST_CASE_SEGRESNET + TEST_CASE_SEGRESNET_2) def test_shape(self, input_param, input_shape, expected_shape): net = SegResNet(**input_param).to(device) @@ -102,6 +103,7 @@ def test_script(self): class TestResNetVAE(unittest.TestCase): + @parameterized.expand(TEST_CASE_SEGRESNET_VAE) def test_vae_shape(self, input_param, input_shape, expected_shape): net = SegResNetVAE(**input_param).to(device) diff --git a/tests/test_segresnet_block.py b/tests/test_segresnet_block.py index 343f39d72c..633507a06a 100644 --- a/tests/test_segresnet_block.py +++ b/tests/test_segresnet_block.py @@ -38,6 +38,7 @@ class TestResBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_RESBLOCK) def test_shape(self, input_param, input_shape, expected_shape): net = ResBlock(**input_param) diff --git a/tests/test_segresnet_ds.py b/tests/test_segresnet_ds.py index a5b88f9724..5372fcc8ae 100644 --- a/tests/test_segresnet_ds.py +++ b/tests/test_segresnet_ds.py @@ -72,6 +72,7 @@ class TestResNetDS(unittest.TestCase): + @parameterized.expand(TEST_CASE_SEGRESNET_DS) def test_shape(self, input_param, input_shape, expected_shape): net = SegResNetDS(**input_param).to(device) diff --git a/tests/test_select_cross_validation_folds.py b/tests/test_select_cross_validation_folds.py index 3ab6c0a9c5..c7d19f34ab 100644 --- a/tests/test_select_cross_validation_folds.py +++ b/tests/test_select_cross_validation_folds.py @@ -43,6 +43,7 @@ class TestSelectCrossValidationFolds(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_value(self, input_param, result): partitions = partition_dataset(**input_param) diff --git a/tests/test_select_itemsd.py b/tests/test_select_itemsd.py index 5eb4a1c51b..f025917b9d 100644 --- a/tests/test_select_itemsd.py +++ b/tests/test_select_itemsd.py @@ -23,6 +23,7 @@ class TestSelectItemsd(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) def test_memory(self, input_param, expected_key_size): input_data = {} diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 0d0553ed2c..d52cc71e55 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -46,6 +46,7 @@ class TestResBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_SABLOCK) @skipUnless(has_einops, "Requires einops") def test_shape(self, input_param, input_shape, expected_shape): diff --git a/tests/test_senet.py b/tests/test_senet.py index 92b5f39ace..6809d4562b 100644 --- a/tests/test_senet.py +++ b/tests/test_senet.py @@ -58,6 +58,7 @@ class TestSENET(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) def test_senet_shape(self, net, net_args): input_data = torch.randn(2, 2, 64, 64, 64).to(device) @@ -75,6 +76,7 @@ def test_script(self, net, net_args): class TestPretrainedSENET(unittest.TestCase): + def setUp(self): self.original_urls = se_mod.SE_NET_MODELS.copy() replace_url = test_is_quick() diff --git a/tests/test_separable_filter.py b/tests/test_separable_filter.py index 1797a649e0..d712f05ee1 100644 --- a/tests/test_separable_filter.py +++ b/tests/test_separable_filter.py @@ -20,6 +20,7 @@ class SeparableFilterTestCase(unittest.TestCase): + def test_1d(self): a = torch.tensor([[list(range(10))]], dtype=torch.float) out = separable_filtering(a, torch.tensor([-1, 0, 1])) diff --git a/tests/test_set_determinism.py b/tests/test_set_determinism.py index aab7af1079..7d64aed244 100644 --- a/tests/test_set_determinism.py +++ b/tests/test_set_determinism.py @@ -21,6 +21,7 @@ class TestSetDeterminism(unittest.TestCase): + def test_values(self): # check system default flags set_determinism(None) @@ -55,6 +56,7 @@ def test_values(self): class TestSetFlag(unittest.TestCase): + def setUp(self): set_determinism(1, use_deterministic_algorithms=True) diff --git a/tests/test_set_visible_devices.py b/tests/test_set_visible_devices.py index 993e8a4ac2..b4f44957a2 100644 --- a/tests/test_set_visible_devices.py +++ b/tests/test_set_visible_devices.py @@ -14,16 +14,18 @@ import os import unittest -from tests.utils import skip_if_no_cuda +from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_no_cuda class TestVisibleDevices(unittest.TestCase): + @staticmethod def run_process_and_get_exit_code(code_to_execute): value = os.system(code_to_execute) return int(bin(value).replace("0b", "").rjust(16, "0")[:8], 2) @skip_if_no_cuda + @SkipIfAtLeastPyTorchVersion((2, 2, 1)) def test_visible_devices(self): num_gpus_before = self.run_process_and_get_exit_code( 'python -c "import os; import torch; ' diff --git a/tests/test_shift_intensity.py b/tests/test_shift_intensity.py index f1bc36036e..90aa0f9271 100644 --- a/tests/test_shift_intensity.py +++ b/tests/test_shift_intensity.py @@ -20,6 +20,7 @@ class TestShiftIntensity(NumpyImageTestCase2D): + def test_value(self): shifter = ShiftIntensity(offset=1.0) result = shifter(self.imt) diff --git a/tests/test_shift_intensityd.py b/tests/test_shift_intensityd.py index e8d163b34a..22336b4415 100644 --- a/tests/test_shift_intensityd.py +++ b/tests/test_shift_intensityd.py @@ -21,6 +21,7 @@ class TestShiftIntensityd(NumpyImageTestCase2D): + def test_value(self): key = "img" for p in TEST_NDARRAYS: diff --git a/tests/test_shuffle_buffer.py b/tests/test_shuffle_buffer.py index 9fcd3a23f6..e75321616b 100644 --- a/tests/test_shuffle_buffer.py +++ b/tests/test_shuffle_buffer.py @@ -23,6 +23,7 @@ @SkipIfBeforePyTorchVersion((1, 12)) class TestShuffleBuffer(unittest.TestCase): + def test_shape(self): buffer = ShuffleBuffer([1, 2, 3, 4], seed=0) num_workers = 2 if sys.platform == "linux" else 0 diff --git a/tests/test_signal_continuouswavelet.py b/tests/test_signal_continuouswavelet.py index 4886168a00..7e6ee8b105 100644 --- a/tests/test_signal_continuouswavelet.py +++ b/tests/test_signal_continuouswavelet.py @@ -29,6 +29,7 @@ @skipUnless(has_pywt, "pywt required") class TestSignalContinousWavelet(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, type, length, frequency): self.assertIsInstance(SignalContinuousWavelet(type, length, frequency), SignalContinuousWavelet) diff --git a/tests/test_signal_fillempty.py b/tests/test_signal_fillempty.py index ee606d960c..a3ee623cc5 100644 --- a/tests/test_signal_fillempty.py +++ b/tests/test_signal_fillempty.py @@ -26,6 +26,7 @@ @SkipIfBeforePyTorchVersion((1, 9)) class TestSignalFillEmptyNumpy(unittest.TestCase): + def test_correct_parameters_multi_channels(self): self.assertIsInstance(SignalFillEmpty(replacement=0.0), SignalFillEmpty) sig = np.load(TEST_SIGNAL) @@ -37,6 +38,7 @@ def test_correct_parameters_multi_channels(self): @SkipIfBeforePyTorchVersion((1, 9)) class TestSignalFillEmptyTorch(unittest.TestCase): + def test_correct_parameters_multi_channels(self): self.assertIsInstance(SignalFillEmpty(replacement=0.0), SignalFillEmpty) sig = convert_to_tensor(np.load(TEST_SIGNAL)) diff --git a/tests/test_signal_fillemptyd.py b/tests/test_signal_fillemptyd.py index 5b12055e7d..ee8c571ef8 100644 --- a/tests/test_signal_fillemptyd.py +++ b/tests/test_signal_fillemptyd.py @@ -26,6 +26,7 @@ @SkipIfBeforePyTorchVersion((1, 9)) class TestSignalFillEmptyNumpy(unittest.TestCase): + def test_correct_parameters_multi_channels(self): self.assertIsInstance(SignalFillEmptyd(replacement=0.0), SignalFillEmptyd) sig = np.load(TEST_SIGNAL) @@ -41,6 +42,7 @@ def test_correct_parameters_multi_channels(self): @SkipIfBeforePyTorchVersion((1, 9)) class TestSignalFillEmptyTorch(unittest.TestCase): + def test_correct_parameters_multi_channels(self): self.assertIsInstance(SignalFillEmptyd(replacement=0.0), SignalFillEmptyd) sig = convert_to_tensor(np.load(TEST_SIGNAL)) diff --git a/tests/test_signal_rand_add_gaussiannoise.py b/tests/test_signal_rand_add_gaussiannoise.py index 2090df876f..e5c9eba8a2 100644 --- a/tests/test_signal_rand_add_gaussiannoise.py +++ b/tests/test_signal_rand_add_gaussiannoise.py @@ -25,6 +25,7 @@ class TestSignalRandAddGaussianNoiseNumpy(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, boundaries): self.assertIsInstance(SignalRandAddGaussianNoise(boundaries), SignalRandAddGaussianNoise) @@ -35,6 +36,7 @@ def test_correct_parameters_multi_channels(self, boundaries): class TestSignalRandAddGaussianNoiseTorch(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, boundaries): self.assertIsInstance(SignalRandAddGaussianNoise(boundaries), SignalRandAddGaussianNoise) diff --git a/tests/test_signal_rand_add_sine.py b/tests/test_signal_rand_add_sine.py index ae0684d608..4ba91247dd 100644 --- a/tests/test_signal_rand_add_sine.py +++ b/tests/test_signal_rand_add_sine.py @@ -25,6 +25,7 @@ class TestSignalRandAddSineNumpy(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, boundaries, freqs): self.assertIsInstance(SignalRandAddSine(boundaries, freqs), SignalRandAddSine) @@ -35,6 +36,7 @@ def test_correct_parameters_multi_channels(self, boundaries, freqs): class TestSignalRandAddSineTorch(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, boundaries, freqs): self.assertIsInstance(SignalRandAddSine(boundaries, freqs), SignalRandAddSine) diff --git a/tests/test_signal_rand_add_sine_partial.py b/tests/test_signal_rand_add_sine_partial.py index 109fb006ea..71b67747a2 100644 --- a/tests/test_signal_rand_add_sine_partial.py +++ b/tests/test_signal_rand_add_sine_partial.py @@ -25,6 +25,7 @@ class TestSignalRandAddSinePartialNumpy(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, boundaries, frequencies, fraction): self.assertIsInstance(SignalRandAddSinePartial(boundaries, frequencies, fraction), SignalRandAddSinePartial) @@ -35,6 +36,7 @@ def test_correct_parameters_multi_channels(self, boundaries, frequencies, fracti class TestSignalRandAddSinePartialTorch(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, boundaries, frequencies, fraction): self.assertIsInstance(SignalRandAddSinePartial(boundaries, frequencies, fraction), SignalRandAddSinePartial) diff --git a/tests/test_signal_rand_add_squarepulse.py b/tests/test_signal_rand_add_squarepulse.py index efbdc9af09..e1432029ea 100644 --- a/tests/test_signal_rand_add_squarepulse.py +++ b/tests/test_signal_rand_add_squarepulse.py @@ -31,6 +31,7 @@ @skipUnless(has_scipy, "scipy required") @SkipIfBeforePyTorchVersion((1, 10, 1)) class TestSignalRandAddSquarePulseNumpy(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, boundaries, frequencies): self.assertIsInstance(SignalRandAddSquarePulse(boundaries, frequencies), SignalRandAddSquarePulse) @@ -43,6 +44,7 @@ def test_correct_parameters_multi_channels(self, boundaries, frequencies): @skipUnless(has_scipy, "scipy required") @SkipIfBeforePyTorchVersion((1, 10, 1)) class TestSignalRandAddSquarePulseTorch(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, boundaries, frequencies): self.assertIsInstance(SignalRandAddSquarePulse(boundaries, frequencies), SignalRandAddSquarePulse) diff --git a/tests/test_signal_rand_add_squarepulse_partial.py b/tests/test_signal_rand_add_squarepulse_partial.py index eee3f5596d..7e1c2bb9d8 100644 --- a/tests/test_signal_rand_add_squarepulse_partial.py +++ b/tests/test_signal_rand_add_squarepulse_partial.py @@ -31,6 +31,7 @@ @skipUnless(has_scipy, "scipy required") @SkipIfBeforePyTorchVersion((1, 10, 1)) class TestSignalRandAddSquarePulsePartialNumpy(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, boundaries, frequencies, fraction): self.assertIsInstance( @@ -45,6 +46,7 @@ def test_correct_parameters_multi_channels(self, boundaries, frequencies, fracti @skipUnless(has_scipy, "scipy required") @SkipIfBeforePyTorchVersion((1, 10, 1)) class TestSignalRandAddSquarePulsePartialTorch(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, boundaries, frequencies, fraction): self.assertIsInstance( diff --git a/tests/test_signal_rand_drop.py b/tests/test_signal_rand_drop.py index 5dcd466481..bf2db75a6a 100644 --- a/tests/test_signal_rand_drop.py +++ b/tests/test_signal_rand_drop.py @@ -25,6 +25,7 @@ class TestSignalRandDropNumpy(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, boundaries): self.assertIsInstance(SignalRandDrop(boundaries), SignalRandDrop) @@ -35,6 +36,7 @@ def test_correct_parameters_multi_channels(self, boundaries): class TestSignalRandDropTorch(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, boundaries): self.assertIsInstance(SignalRandDrop(boundaries), SignalRandDrop) diff --git a/tests/test_signal_rand_scale.py b/tests/test_signal_rand_scale.py index 126d7cca65..c040c59a1f 100644 --- a/tests/test_signal_rand_scale.py +++ b/tests/test_signal_rand_scale.py @@ -25,6 +25,7 @@ class TestSignalRandScaleNumpy(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, boundaries): self.assertIsInstance(SignalRandScale(boundaries), SignalRandScale) @@ -35,6 +36,7 @@ def test_correct_parameters_multi_channels(self, boundaries): class TestSignalRandScaleTorch(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, boundaries): self.assertIsInstance(SignalRandScale(boundaries), SignalRandScale) diff --git a/tests/test_signal_rand_shift.py b/tests/test_signal_rand_shift.py index ed25cc8b1f..96809e7446 100644 --- a/tests/test_signal_rand_shift.py +++ b/tests/test_signal_rand_shift.py @@ -29,6 +29,7 @@ @skipUnless(has_scipy, "scipy required") class TestSignalRandShiftNumpy(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, mode, filling, boundaries): self.assertIsInstance(SignalRandShift(mode, filling, boundaries), SignalRandShift) @@ -40,6 +41,7 @@ def test_correct_parameters_multi_channels(self, mode, filling, boundaries): @skipUnless(has_scipy, "scipy required") class TestSignalRandShiftTorch(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, mode, filling, boundaries): self.assertIsInstance(SignalRandShift(mode, filling, boundaries), SignalRandShift) diff --git a/tests/test_signal_remove_frequency.py b/tests/test_signal_remove_frequency.py index b18de36c08..9f795ce68b 100644 --- a/tests/test_signal_remove_frequency.py +++ b/tests/test_signal_remove_frequency.py @@ -31,6 +31,7 @@ @skipUnless(has_scipy and has_torchaudio, "scipy and torchaudio are required") class TestSignalRemoveFrequencyNumpy(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, frequency, quality_factor, sampling_freq): self.assertIsInstance(SignalRemoveFrequency(frequency, quality_factor, sampling_freq), SignalRemoveFrequency) @@ -49,6 +50,7 @@ def test_correct_parameters_multi_channels(self, frequency, quality_factor, samp @skipUnless(has_scipy and has_torchaudio, "scipy and torchaudio are required") class TestSignalRemoveFrequencyTorch(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct_parameters_multi_channels(self, frequency, quality_factor, sampling_freq): self.assertIsInstance(SignalRemoveFrequency(frequency, quality_factor, sampling_freq), SignalRemoveFrequency) diff --git a/tests/test_simple_aspp.py b/tests/test_simple_aspp.py index f18b208e9c..da7540d45e 100644 --- a/tests/test_simple_aspp.py +++ b/tests/test_simple_aspp.py @@ -69,6 +69,7 @@ class TestChannelSELayer(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_shape, expected_shape): net = SimpleASPP(**input_param) diff --git a/tests/test_simulatedelay.py b/tests/test_simulatedelay.py index 5cf47b245e..0a4f23450a 100644 --- a/tests/test_simulatedelay.py +++ b/tests/test_simulatedelay.py @@ -22,6 +22,7 @@ class TestSimulateDelay(NumpyImageTestCase2D): + @parameterized.expand([(0.45,), (1,)]) def test_value(self, delay_test_time: float): resize = SimulateDelay(delay_time=delay_test_time) diff --git a/tests/test_simulatedelayd.py b/tests/test_simulatedelayd.py index 827fe69510..419e21f24d 100644 --- a/tests/test_simulatedelayd.py +++ b/tests/test_simulatedelayd.py @@ -22,6 +22,7 @@ class TestSimulateDelay(NumpyImageTestCase2D): + @parameterized.expand([(0.45,), (1,)]) def test_value(self, delay_test_time: float): resize = SimulateDelayd(keys="imgd", delay_time=delay_test_time) diff --git a/tests/test_skip_connection.py b/tests/test_skip_connection.py index 0ac8ef0d7a..5ee166cf10 100644 --- a/tests/test_skip_connection.py +++ b/tests/test_skip_connection.py @@ -31,6 +31,7 @@ class TestSkipConnection(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) def test_shape(self, input_param, input_shape, expected_shape): net = SkipConnection(submodule=torch.nn.Softmax(dim=1), **input_param) diff --git a/tests/test_slice_inferer.py b/tests/test_slice_inferer.py index 4d7dea026f..526542943e 100644 --- a/tests/test_slice_inferer.py +++ b/tests/test_slice_inferer.py @@ -23,6 +23,7 @@ class TestSliceInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, spatial_dim): spatial_dim = int(spatial_dim) diff --git a/tests/test_sliding_patch_wsi_dataset.py b/tests/test_sliding_patch_wsi_dataset.py index 518e94552f..6369613426 100644 --- a/tests/test_sliding_patch_wsi_dataset.py +++ b/tests/test_sliding_patch_wsi_dataset.py @@ -213,6 +213,7 @@ def setUpModule(): class SlidingPatchWSIDatasetTests: + class Tests(unittest.TestCase): backend = None @@ -252,6 +253,7 @@ def test_read_patches_large(self, input_parameters, expected): @skipUnless(has_cucim, "Requires cucim") class TestSlidingPatchWSIDatasetCuCIM(SlidingPatchWSIDatasetTests.Tests): + @classmethod def setUpClass(cls): cls.backend = "cucim" @@ -259,6 +261,7 @@ def setUpClass(cls): @skipUnless(has_osl, "Requires openslide") class TestSlidingPatchWSIDatasetOpenSlide(SlidingPatchWSIDatasetTests.Tests): + @classmethod def setUpClass(cls): cls.backend = "openslide" diff --git a/tests/test_sliding_window_hovernet_inference.py b/tests/test_sliding_window_hovernet_inference.py index 276bd1e372..6fc9240a13 100644 --- a/tests/test_sliding_window_hovernet_inference.py +++ b/tests/test_sliding_window_hovernet_inference.py @@ -36,6 +36,7 @@ class TestSlidingWindowHoVerNetInference(unittest.TestCase): + @parameterized.expand(TEST_CASES_PADDING) def test_sliding_window_with_padding( self, key, image_shape, roi_shape, sw_batch_size, overlap, mode, device, extra_input_padding diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py index 8f0c074403..33b38a5bc7 100644 --- a/tests/test_sliding_window_inference.py +++ b/tests/test_sliding_window_inference.py @@ -70,8 +70,10 @@ class TestSlidingWindowInference(unittest.TestCase): + @parameterized.expand(BUFFER_CASES) def test_buffers(self, size_params, buffer_steps, buffer_dim, device_params): + def mult_two(patch, *args, **kwargs): return 2.0 * patch diff --git a/tests/test_sliding_window_splitter.py b/tests/test_sliding_window_splitter.py index 015293cbee..ad136c61a4 100644 --- a/tests/test_sliding_window_splitter.py +++ b/tests/test_sliding_window_splitter.py @@ -236,6 +236,7 @@ def missing_parameter_filter(patch): class SlidingWindowSplitterTests(unittest.TestCase): + @parameterized.expand( [ TEST_CASE_TENSOR_0, diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py index 0e2a79fef3..bb43060469 100644 --- a/tests/test_smartcachedataset.py +++ b/tests/test_smartcachedataset.py @@ -38,6 +38,7 @@ class TestSmartCacheDataset(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_shape(self, replace_rate, num_replace_workers, transform): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[8, 8, 8]).astype(float), np.eye(4)) diff --git a/tests/test_smooth_field.py b/tests/test_smooth_field.py index c525311478..ca010641c4 100644 --- a/tests/test_smooth_field.py +++ b/tests/test_smooth_field.py @@ -88,6 +88,7 @@ class TestSmoothField(unittest.TestCase): + @parameterized.expand(TESTS_CONTRAST) def test_rand_smooth_field_adjust_contrastd(self, input_param, input_data, expected_val): g = RandSmoothFieldAdjustContrastd(**input_param) diff --git a/tests/test_soft_clip.py b/tests/test_soft_clip.py new file mode 100644 index 0000000000..de5122e982 --- /dev/null +++ b/tests/test_soft_clip.py @@ -0,0 +1,125 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms.utils import soft_clip + +TEST_CASES = [ + [ + {"minv": 2, "maxv": 8, "sharpness_factor": 10}, + { + "input": torch.arange(10).float(), + "clipped": torch.tensor([2.0000, 2.0000, 2.0693, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 7.9307, 8.0000]), + }, + ], + [ + {"minv": 2, "maxv": None, "sharpness_factor": 10}, + { + "input": torch.arange(10).float(), + "clipped": torch.tensor([2.0000, 2.0000, 2.0693, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000]), + }, + ], + [ + {"minv": None, "maxv": 7, "sharpness_factor": 10}, + { + "input": torch.arange(10).float(), + "clipped": torch.tensor([0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 6.9307, 7.0000, 7.0000]), + }, + ], + [ + {"minv": 2, "maxv": 8, "sharpness_factor": 1.0}, + { + "input": torch.arange(10).float(), + "clipped": torch.tensor([2.1266, 2.3124, 2.6907, 3.3065, 4.1088, 5.0000, 5.8912, 6.6935, 7.3093, 7.6877]), + }, + ], + [ + {"minv": 2, "maxv": 8, "sharpness_factor": 3.0}, + { + "input": torch.arange(10).float(), + "clipped": torch.tensor([2.0008, 2.0162, 2.2310, 3.0162, 4.0008, 5.0000, 5.9992, 6.9838, 7.7690, 7.9838]), + }, + ], + [ + {"minv": 2, "maxv": 8, "sharpness_factor": 5.0}, + { + "input": torch.arange(10).float(), + "clipped": torch.tensor([2.0000, 2.0013, 2.1386, 3.0013, 4.0000, 5.0000, 6.0000, 6.9987, 7.8614, 7.9987]), + }, + ], + [ + {"minv": 2, "maxv": 8, "sharpness_factor": 10}, + { + "input": np.arange(10).astype(np.float32), + "clipped": np.array([2.0000, 2.0000, 2.0693, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 7.9307, 8.0000]), + }, + ], + [ + {"minv": 2, "maxv": None, "sharpness_factor": 10}, + { + "input": np.arange(10).astype(float), + "clipped": np.array([2.0000, 2.0000, 2.0693, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000]), + }, + ], + [ + {"minv": None, "maxv": 7, "sharpness_factor": 10}, + { + "input": np.arange(10).astype(float), + "clipped": np.array([0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 6.9307, 7.0000, 7.0000]), + }, + ], + [ + {"minv": 2, "maxv": 8, "sharpness_factor": 1.0}, + { + "input": np.arange(10).astype(float), + "clipped": np.array([2.1266, 2.3124, 2.6907, 3.3065, 4.1088, 5.0000, 5.8912, 6.6935, 7.3093, 7.6877]), + }, + ], + [ + {"minv": 2, "maxv": 8, "sharpness_factor": 3.0}, + { + "input": np.arange(10).astype(float), + "clipped": np.array([2.0008, 2.0162, 2.2310, 3.0162, 4.0008, 5.0000, 5.9992, 6.9838, 7.7690, 7.9838]), + }, + ], + [ + {"minv": 2, "maxv": 8, "sharpness_factor": 5.0}, + { + "input": np.arange(10).astype(float), + "clipped": np.array([2.0000, 2.0013, 2.1386, 3.0013, 4.0000, 5.0000, 6.0000, 6.9987, 7.8614, 7.9987]), + }, + ], +] + + +class TestSoftClip(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_result(self, input_param, input_data): + outputs = soft_clip(input_data["input"], **input_param) + expected_val = input_data["clipped"] + if isinstance(outputs, torch.Tensor): + np.testing.assert_allclose( + outputs.detach().cpu().numpy(), expected_val.detach().cpu().numpy(), atol=1e-4, rtol=1e-4 + ) + else: + np.testing.assert_allclose(outputs, expected_val, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_some_of.py b/tests/test_some_of.py index 8880c376b9..3723732d51 100644 --- a/tests/test_some_of.py +++ b/tests/test_some_of.py @@ -31,21 +31,25 @@ class A(Transform): + def __call__(self, x): return 2 * x class B(Transform): + def __call__(self, x): return 3 * x class C(Transform): + def __call__(self, x): return 5 * x class D(Transform): + def __call__(self, x): return 7 * x @@ -71,6 +75,7 @@ def __call__(self, x): class TestSomeOf(unittest.TestCase): + def setUp(self): set_determinism(seed=0) @@ -221,6 +226,7 @@ def test_bad_num_transforms(self): class TestSomeOfAPITests(unittest.TestCase): + @staticmethod def data_from_keys(keys): if keys is None: diff --git a/tests/test_spacing.py b/tests/test_spacing.py index 8b664641d7..c9a6291c78 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -271,6 +271,7 @@ @skip_if_quick class TestSpacingCase(unittest.TestCase): + @parameterized.expand(TESTS) def test_spacing( self, diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py index 36986b2706..1cecaabced 100644 --- a/tests/test_spacingd.py +++ b/tests/test_spacingd.py @@ -105,6 +105,7 @@ class TestSpacingDCase(unittest.TestCase): + @parameterized.expand(TESTS) def test_spacingd(self, _, data, kw_args, expected_shape, expected_affine, device): data = {k: v.to(device) for k, v in data.items()} diff --git a/tests/test_spade_autoencoderkl.py b/tests/test_spade_autoencoderkl.py index 6675a6db67..9353ceedc2 100644 --- a/tests/test_spade_autoencoderkl.py +++ b/tests/test_spade_autoencoderkl.py @@ -12,14 +12,18 @@ from __future__ import annotations import unittest +from unittest import skipUnless import torch from parameterized import parameterized from monai.networks import eval_mode from monai.networks.nets import SPADEAutoencoderKL +from monai.utils import optional_import -CASES = [ +einops, has_einops = optional_import("einops") + +CASES_NO_ATTENTION = [ [ { "spatial_dims": 2, @@ -31,12 +35,36 @@ "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, }, (1, 1, 16, 16), (1, 3, 16, 16), (1, 1, 16, 16), (1, 4, 4, 4), ], + [ + { + "spatial_dims": 3, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + }, + (1, 1, 16, 16, 16), + (1, 3, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], +] + +CASES_ATTENTION = [ [ { "spatial_dims": 2, @@ -46,7 +74,7 @@ "channels": (4, 4, 4), "latent_channels": 4, "attention_levels": (False, False, False), - "num_res_blocks": (1, 1, 2), + "num_res_blocks": 1, "norm_num_groups": 4, }, (1, 1, 16, 16), @@ -63,7 +91,7 @@ "channels": (4, 4, 4), "latent_channels": 4, "attention_levels": (False, False, False), - "num_res_blocks": 1, + "num_res_blocks": (1, 1, 2), "norm_num_groups": 4, }, (1, 1, 16, 16), @@ -79,7 +107,7 @@ "out_channels": 1, "channels": (4, 4, 4), "latent_channels": 4, - "attention_levels": (False, False, True), + "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, }, @@ -96,10 +124,9 @@ "out_channels": 1, "channels": (4, 4, 4), "latent_channels": 4, - "attention_levels": (False, False, False), + "attention_levels": (False, False, True), "num_res_blocks": 1, "norm_num_groups": 4, - "with_encoder_nonlocal_attn": False, }, (1, 1, 16, 16), (1, 3, 16, 16), @@ -118,7 +145,6 @@ "num_res_blocks": 1, "norm_num_groups": 4, "with_encoder_nonlocal_attn": False, - "with_decoder_nonlocal_attn": False, }, (1, 1, 16, 16), (1, 3, 16, 16), @@ -164,6 +190,11 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +if has_einops: + CASES = CASES_ATTENTION + CASES_NO_ATTENTION +else: + CASES = CASES_NO_ATTENTION + class TestSPADEAutoEncoderKL(unittest.TestCase): @parameterized.expand(CASES) @@ -174,6 +205,7 @@ def test_shape(self, input_param, input_shape, input_seg, expected_shape, expect self.assertEqual(result[0].shape, expected_shape) self.assertEqual(result[1].shape, expected_latent_shape) + @skipUnless(has_einops, "Requires einops") def test_model_channels_not_multiple_of_norm_num_group(self): with self.assertRaises(ValueError): SPADEAutoencoderKL( @@ -188,6 +220,7 @@ def test_model_channels_not_multiple_of_norm_num_group(self): norm_num_groups=16, ) + @skipUnless(has_einops, "Requires einops") def test_model_channels_not_same_size_of_attention_levels(self): with self.assertRaises(ValueError): SPADEAutoencoderKL( @@ -202,6 +235,7 @@ def test_model_channels_not_same_size_of_attention_levels(self): norm_num_groups=16, ) + @skipUnless(has_einops, "Requires einops") def test_model_channels_not_same_size_of_num_res_blocks(self): with self.assertRaises(ValueError): SPADEAutoencoderKL( @@ -240,6 +274,7 @@ def test_shape_decode(self): result = net.decode(torch.randn(latent_shape).to(device), torch.randn(input_seg_shape).to(device)) self.assertEqual(result.shape, expected_input_shape) + @skipUnless(has_einops, "Requires einops") def test_wrong_shape_decode(self): net = SPADEAutoencoderKL( spatial_dims=2, diff --git a/tests/test_spatial_combine_transforms.py b/tests/test_spatial_combine_transforms.py index 8594daed16..8479e9084b 100644 --- a/tests/test_spatial_combine_transforms.py +++ b/tests/test_spatial_combine_transforms.py @@ -132,6 +132,7 @@ class CombineLazyTest(unittest.TestCase): + @parameterized.expand(TEST_2D + TEST_3D) def test_combine_transforms(self, input_shape, funcs): for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: diff --git a/tests/test_spatial_resample.py b/tests/test_spatial_resample.py index b513bd0f05..e64b242128 100644 --- a/tests/test_spatial_resample.py +++ b/tests/test_spatial_resample.py @@ -133,6 +133,7 @@ class TestSpatialResample(unittest.TestCase): + @parameterized.expand(TESTS) def test_flips(self, img, device, data_param, expected_output): for p in TEST_NDARRAYS_ALL: diff --git a/tests/test_spatial_resampled.py b/tests/test_spatial_resampled.py index ebe3eb6e4f..d5c86258d7 100644 --- a/tests/test_spatial_resampled.py +++ b/tests/test_spatial_resampled.py @@ -11,6 +11,7 @@ from __future__ import annotations +import platform import unittest import numpy as np @@ -23,6 +24,12 @@ from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_DEVICES, assert_allclose +ON_AARCH64 = platform.machine() == "aarch64" +if ON_AARCH64: + rtol, atol = 1e-1, 1e-2 +else: + rtol, atol = 1e-3, 1e-4 + TESTS = [] destinations_3d = [ @@ -87,6 +94,7 @@ class TestSpatialResample(unittest.TestCase): + @parameterized.expand(TESTS) def test_flips_inverse(self, img, device, dst_affine, kwargs, expected_output): img = MetaTensor(img, affine=torch.eye(4)).to(device) @@ -103,7 +111,7 @@ def test_flips_inverse(self, img, device, dst_affine, kwargs, expected_output): # check lazy lazy_xform = SpatialResampled(**init_param) - test_resampler_lazy(lazy_xform, output_data, init_param, call_param, output_key="img") + test_resampler_lazy(lazy_xform, output_data, init_param, call_param, output_key="img", rtol=rtol, atol=atol) # check inverse inverted = xform.inverse(output_data)["img"] diff --git a/tests/test_spectral_loss.py b/tests/test_spectral_loss.py index 21b5c48de4..f62ae9030b 100644 --- a/tests/test_spectral_loss.py +++ b/tests/test_spectral_loss.py @@ -63,6 +63,7 @@ class TestJukeboxLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_results(self, input_param, input_data, expected_val): results = JukeboxLoss(**input_param).forward(**input_data) diff --git a/tests/test_splitdim.py b/tests/test_splitdim.py index 6c678a6bc2..f557f44142 100644 --- a/tests/test_splitdim.py +++ b/tests/test_splitdim.py @@ -26,6 +26,7 @@ class TestSplitDim(unittest.TestCase): + @parameterized.expand(TESTS) def test_correct_shape(self, shape, keepdim, im_type): arr = im_type(np.random.rand(*shape)) diff --git a/tests/test_squeeze_unsqueeze.py b/tests/test_squeeze_unsqueeze.py index 130a214345..3f818f905b 100644 --- a/tests/test_squeeze_unsqueeze.py +++ b/tests/test_squeeze_unsqueeze.py @@ -61,6 +61,7 @@ class TestUnsqueeze(unittest.TestCase): + @parameterized.expand(RIGHT_CASES + ALL_CASES) def test_unsqueeze_right(self, arr, ndim, shape): self.assertEqual(unsqueeze_right(arr, ndim).shape, shape) diff --git a/tests/test_squeezedim.py b/tests/test_squeezedim.py index 6673fd25c1..a295d20ef5 100644 --- a/tests/test_squeezedim.py +++ b/tests/test_squeezedim.py @@ -32,6 +32,7 @@ class TestSqueezeDim(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, input_param, test_data, expected_shape): result = SqueezeDim(**input_param)(test_data) diff --git a/tests/test_squeezedimd.py b/tests/test_squeezedimd.py index 9fa9d84030..934479563d 100644 --- a/tests/test_squeezedimd.py +++ b/tests/test_squeezedimd.py @@ -80,6 +80,7 @@ class TestSqueezeDim(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, input_param, test_data, expected_shape): result = SqueezeDimd(**input_param)(test_data) diff --git a/tests/test_ssim_loss.py b/tests/test_ssim_loss.py index db80eb80db..7fa593b956 100644 --- a/tests/test_ssim_loss.py +++ b/tests/test_ssim_loss.py @@ -23,6 +23,7 @@ class TestSSIMLoss(unittest.TestCase): + def test_shape(self): set_determinism(0) preds = torch.abs(torch.randn(2, 3, 16, 16)) diff --git a/tests/test_ssim_metric.py b/tests/test_ssim_metric.py index 467e478937..d79107e999 100644 --- a/tests/test_ssim_metric.py +++ b/tests/test_ssim_metric.py @@ -20,6 +20,7 @@ class TestSSIMMetric(unittest.TestCase): + def test2d_gaussian(self): set_determinism(0) preds = torch.abs(torch.randn(2, 3, 16, 16)) diff --git a/tests/test_state_cacher.py b/tests/test_state_cacher.py index 2037dc3951..22c2836239 100644 --- a/tests/test_state_cacher.py +++ b/tests/test_state_cacher.py @@ -36,6 +36,7 @@ class TestStateCacher(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_state_cacher(self, data_obj, params): key = "data_obj" diff --git a/tests/test_std_shift_intensity.py b/tests/test_std_shift_intensity.py index af18c18aa2..b4dc1db568 100644 --- a/tests/test_std_shift_intensity.py +++ b/tests/test_std_shift_intensity.py @@ -21,6 +21,7 @@ class TestStdShiftIntensity(NumpyImageTestCase2D): + def test_value(self): for p in TEST_NDARRAYS: imt = p(self.imt) diff --git a/tests/test_std_shift_intensityd.py b/tests/test_std_shift_intensityd.py index 6cb7d416c7..73617ef4a3 100644 --- a/tests/test_std_shift_intensityd.py +++ b/tests/test_std_shift_intensityd.py @@ -21,6 +21,7 @@ class TestStdShiftIntensityd(NumpyImageTestCase2D): + def test_value(self): key = "img" factor = np.random.rand() diff --git a/tests/test_str2bool.py b/tests/test_str2bool.py index 36f99b4064..af932b1df8 100644 --- a/tests/test_str2bool.py +++ b/tests/test_str2bool.py @@ -17,6 +17,7 @@ class TestStr2Bool(unittest.TestCase): + def test_str_2_bool(self): for i in ("yes", "true", "t", "y", "1", True): self.assertTrue(str2bool(i)) diff --git a/tests/test_str2list.py b/tests/test_str2list.py index b442925fb3..e1531373cb 100644 --- a/tests/test_str2list.py +++ b/tests/test_str2list.py @@ -17,6 +17,7 @@ class TestStr2List(unittest.TestCase): + def test_str_2_list(self): for i in ("1,2,3", "1, 2, 3", "1,2e-0,3.0", [1, 2, 3]): self.assertEqual(str2list(i), [1, 2, 3]) diff --git a/tests/test_subpixel_upsample.py b/tests/test_subpixel_upsample.py index a6de8dd846..5abbe57e11 100644 --- a/tests/test_subpixel_upsample.py +++ b/tests/test_subpixel_upsample.py @@ -68,6 +68,7 @@ class TestSUBPIXEL(unittest.TestCase): + @parameterized.expand(TEST_CASE_SUBPIXEL) def test_subpixel_shape(self, input_param, input_shape, expected_shape): net = SubpixelUpsample(**input_param) diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py new file mode 100644 index 0000000000..903f9bd2ca --- /dev/null +++ b/tests/test_sure_loss.py @@ -0,0 +1,72 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.losses import SURELoss + + +class TestSURELoss(unittest.TestCase): + + def test_real_value(self): + """Test SURELoss with real-valued input: when the input is real value, the loss should be 0.0.""" + sure_loss_real = SURELoss(perturb_noise=torch.zeros(2, 1, 128, 128), eps=0.1) + + def operator(x): + return x + + y_pseudo_gt = torch.randn(2, 1, 128, 128) + x = torch.randn(2, 1, 128, 128) + loss = sure_loss_real(operator, x, y_pseudo_gt, complex_input=False) + self.assertAlmostEqual(loss.item(), 0.0) + + def test_complex_value(self): + """Test SURELoss with complex-valued input: when the input is complex value, the loss should be 0.0.""" + + def operator(x): + return x + + sure_loss_complex = SURELoss(perturb_noise=torch.zeros(2, 2, 128, 128), eps=0.1) + y_pseudo_gt = torch.randn(2, 2, 128, 128) + x = torch.randn(2, 2, 128, 128) + loss = sure_loss_complex(operator, x, y_pseudo_gt, complex_input=True) + self.assertAlmostEqual(loss.item(), 0.0) + + def test_complex_general_input(self): + """Test SURELoss with complex-valued input: when the input is general complex value, the loss should be 0.0.""" + + def operator(x): + return x + + perturb_noise_real = torch.randn(2, 1, 128, 128) + perturb_noise_complex = torch.zeros(2, 2, 128, 128) + perturb_noise_complex[:, 0, :, :] = perturb_noise_real.squeeze() + y_pseudo_gt_real = torch.randn(2, 1, 128, 128) + y_pseudo_gt_complex = torch.zeros(2, 2, 128, 128) + y_pseudo_gt_complex[:, 0, :, :] = y_pseudo_gt_real.squeeze() + x_real = torch.randn(2, 1, 128, 128) + x_complex = torch.zeros(2, 2, 128, 128) + x_complex[:, 0, :, :] = x_real.squeeze() + + sure_loss_real = SURELoss(perturb_noise=perturb_noise_real, eps=0.1) + sure_loss_complex = SURELoss(perturb_noise=perturb_noise_complex, eps=0.1) + + loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False) + loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True) + self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=6) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_surface_dice.py b/tests/test_surface_dice.py index 53b0d38bb2..2ef19a4eea 100644 --- a/tests/test_surface_dice.py +++ b/tests/test_surface_dice.py @@ -24,6 +24,7 @@ class TestAllSurfaceDiceMetrics(unittest.TestCase): + def test_tolerance_euclidean_distance_with_spacing(self): batch_size = 2 n_class = 2 diff --git a/tests/test_surface_distance.py b/tests/test_surface_distance.py index 81ddee107b..85db389f80 100644 --- a/tests/test_surface_distance.py +++ b/tests/test_surface_distance.py @@ -142,6 +142,7 @@ def create_spherical_seg_3d( class TestAllSurfaceMetrics(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_value(self, input_data, expected_value): if len(input_data) == 3: diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index e34e5a3c8e..5b33475c7e 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -76,6 +76,7 @@ class TestSWINUNETR(unittest.TestCase): + @parameterized.expand(TEST_CASE_SWIN_UNETR) @skipUnless(has_einops, "Requires einops") def test_shape(self, input_param, input_shape, expected_shape): diff --git a/tests/test_synthetic.py b/tests/test_synthetic.py index 116897e67d..7db3c3e77a 100644 --- a/tests/test_synthetic.py +++ b/tests/test_synthetic.py @@ -41,6 +41,7 @@ class TestDiceCELoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_create_test_image(self, dim, input_param, expected_img, expected_seg, expected_shape, expected_max_cls): set_determinism(seed=0) diff --git a/tests/test_tciadataset.py b/tests/test_tciadataset.py index 2a3928f9aa..5a16bb4816 100644 --- a/tests/test_tciadataset.py +++ b/tests/test_tciadataset.py @@ -23,6 +23,7 @@ class TestTciaDataset(unittest.TestCase): + @skip_if_quick def test_values(self): testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") @@ -107,7 +108,7 @@ def _test_dataset(dataset): )[0] shutil.rmtree(os.path.join(testing_dir, collection)) - try: + with self.assertRaisesRegex(RuntimeError, "^Cannot find dataset directory"): TciaDataset( root_dir=testing_dir, collection=collection, @@ -116,8 +117,6 @@ def _test_dataset(dataset): download=False, val_frac=val_frac, ) - except RuntimeError as e: - self.assertTrue(str(e).startswith("Cannot find dataset directory")) if __name__ == "__main__": diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index cbb78ec64d..746ad122b2 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -52,6 +52,7 @@ class TestTestTimeAugmentation(unittest.TestCase): + @staticmethod def get_data(num_examples, input_size, data_type=np.asarray, include_label=True): custom_create_test_image_2d = partial( diff --git a/tests/test_text_encoding.py b/tests/test_text_encoding.py index 831803de8c..902f7a4b1d 100644 --- a/tests/test_text_encoding.py +++ b/tests/test_text_encoding.py @@ -18,6 +18,7 @@ class TestTextEncoder(unittest.TestCase): + def test_test_encoding_shape(self): with skip_if_downloading_fails(): # test 2D encoder diff --git a/tests/test_thread_buffer.py b/tests/test_thread_buffer.py index ab5dba77be..2b7da2c0b0 100644 --- a/tests/test_thread_buffer.py +++ b/tests/test_thread_buffer.py @@ -24,6 +24,7 @@ class TestDataLoader(unittest.TestCase): + def setUp(self): super().setUp() diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py index ca9fb244fc..9551dec703 100644 --- a/tests/test_threadcontainer.py +++ b/tests/test_threadcontainer.py @@ -36,6 +36,7 @@ class TestThreadContainer(unittest.TestCase): + @SkipIfNoModule("ignite") def test_container(self): net = torch.nn.Conv2d(1, 1, 3, padding=1) diff --git a/tests/test_threshold_intensity.py b/tests/test_threshold_intensity.py index 7fb28d413f..97c80eebcd 100644 --- a/tests/test_threshold_intensity.py +++ b/tests/test_threshold_intensity.py @@ -27,6 +27,7 @@ class TestThresholdIntensity(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, in_type, input_param, expected_value): test_data = in_type(np.arange(10)) diff --git a/tests/test_threshold_intensityd.py b/tests/test_threshold_intensityd.py index d5e7e5f517..867ebfe952 100644 --- a/tests/test_threshold_intensityd.py +++ b/tests/test_threshold_intensityd.py @@ -45,6 +45,7 @@ class TestThresholdIntensityd(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, in_type, input_param, expected_value): test_data = {"image": in_type(np.arange(10)), "label": in_type(np.arange(10)), "extra": in_type(np.arange(10))} diff --git a/tests/test_timedcall_dist.py b/tests/test_timedcall_dist.py index af7cf8720f..a814a99b25 100644 --- a/tests/test_timedcall_dist.py +++ b/tests/test_timedcall_dist.py @@ -50,6 +50,7 @@ def case_1_seconds_bad(arg=None): class TestTimedCall(unittest.TestCase): + def test_good_call(self): output = case_1_seconds() self.assertEqual(output, "good") diff --git a/tests/test_to_contiguous.py b/tests/test_to_contiguous.py index 03733b9775..73a9ca27f6 100644 --- a/tests/test_to_contiguous.py +++ b/tests/test_to_contiguous.py @@ -21,6 +21,7 @@ class TestToContiguous(unittest.TestCase): + def test_contiguous_dict(self): tochange = np.moveaxis(np.zeros((2, 3, 4)), 0, -1) test_dict = {"test_key": [[1]], 0: np.array(0), 1: np.array([0]), "nested": {"nested": [tochange]}} diff --git a/tests/test_to_cupy.py b/tests/test_to_cupy.py index 12a377181d..5a1754e7c5 100644 --- a/tests/test_to_cupy.py +++ b/tests/test_to_cupy.py @@ -26,6 +26,7 @@ @skipUnless(HAS_CUPY, "CuPy is required.") class TestToCupy(unittest.TestCase): + def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]], dtype=cp.float32) test_data = cp.rot90(test_data) diff --git a/tests/test_to_cupyd.py b/tests/test_to_cupyd.py index e9a3488489..a07ab671e1 100644 --- a/tests/test_to_cupyd.py +++ b/tests/test_to_cupyd.py @@ -26,6 +26,7 @@ @skipUnless(HAS_CUPY, "CuPy is required.") class TestToCupyd(unittest.TestCase): + def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) diff --git a/tests/test_to_device.py b/tests/test_to_device.py index cad2b65316..6a13ffca99 100644 --- a/tests/test_to_device.py +++ b/tests/test_to_device.py @@ -30,6 +30,7 @@ @skip_if_no_cuda class TestToDevice(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_value(self, device): converter = ToDevice(device=device, non_blocking=True) diff --git a/tests/test_to_deviced.py b/tests/test_to_deviced.py index 093c3b0c4d..19c2d0761f 100644 --- a/tests/test_to_deviced.py +++ b/tests/test_to_deviced.py @@ -22,6 +22,7 @@ @skip_if_no_cuda class TestToDeviced(unittest.TestCase): + def test_value(self): device = "cuda:0" data = [{"img": torch.tensor(i)} for i in range(4)] diff --git a/tests/test_to_from_meta_tensord.py b/tests/test_to_from_meta_tensord.py index 470826313a..fe777cec77 100644 --- a/tests/test_to_from_meta_tensord.py +++ b/tests/test_to_from_meta_tensord.py @@ -42,6 +42,7 @@ def rand_string(min_len=5, max_len=10): @unittest.skipIf(config.USE_META_DICT, "skipping not metatensor") class TestToFromMetaTensord(unittest.TestCase): + @staticmethod def get_im(shape=None, dtype=None, device=None): if shape is None: diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py index 0c604fb9d4..f92b7c0075 100644 --- a/tests/test_to_numpy.py +++ b/tests/test_to_numpy.py @@ -25,6 +25,7 @@ class TestToNumpy(unittest.TestCase): + @skipUnless(HAS_CUPY, "CuPy is required.") def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) @@ -70,7 +71,7 @@ def test_list_tuple(self): assert_allclose(result, np.asarray(test_data), type_test=False) test_data = ((1, 2), (3, 4)) result = ToNumpy(wrap_sequence=False)(test_data) - self.assertTrue(type(result), tuple) + self.assertIsInstance(result, tuple) assert_allclose(result, ((np.asarray(1), np.asarray(2)), (np.asarray(3), np.asarray(4)))) def test_single_value(self): diff --git a/tests/test_to_numpyd.py b/tests/test_to_numpyd.py index d25bdf14a5..ae9b4c84b3 100644 --- a/tests/test_to_numpyd.py +++ b/tests/test_to_numpyd.py @@ -25,6 +25,7 @@ class TestToNumpyd(unittest.TestCase): + @skipUnless(HAS_CUPY, "CuPy is required.") def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) diff --git a/tests/test_to_onehot.py b/tests/test_to_onehot.py index 52307900af..48dba6fa68 100644 --- a/tests/test_to_onehot.py +++ b/tests/test_to_onehot.py @@ -44,6 +44,7 @@ class TestToOneHot(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_shape(self, input_data, expected_shape, expected_result=None): result = one_hot(**input_data) diff --git a/tests/test_to_pil.py b/tests/test_to_pil.py index e4f74f6e1e..352e10bcc1 100644 --- a/tests/test_to_pil.py +++ b/tests/test_to_pil.py @@ -40,6 +40,7 @@ class TestToPIL(unittest.TestCase): + @parameterized.expand(TESTS) @skipUnless(has_pil, "Requires `pillow` package.") def test_value(self, test_data): diff --git a/tests/test_to_pild.py b/tests/test_to_pild.py index 4eb5999b15..1a0232e134 100644 --- a/tests/test_to_pild.py +++ b/tests/test_to_pild.py @@ -38,6 +38,7 @@ class TestToPIL(unittest.TestCase): + @parameterized.expand(TESTS) @skipUnless(has_pil, "Requires `pillow` package.") def test_values(self, input_param, test_data): diff --git a/tests/test_to_tensor.py b/tests/test_to_tensor.py index cde845c246..50df80128b 100644 --- a/tests/test_to_tensor.py +++ b/tests/test_to_tensor.py @@ -33,6 +33,7 @@ class TestToTensor(unittest.TestCase): + @parameterized.expand(TESTS) def test_array_input(self, test_data, expected_shape): result = ToTensor(dtype=torch.float32, device="cpu", wrap_sequence=True)(test_data) diff --git a/tests/test_to_tensord.py b/tests/test_to_tensord.py index 82456786fd..1eab7b9485 100644 --- a/tests/test_to_tensord.py +++ b/tests/test_to_tensord.py @@ -34,6 +34,7 @@ class TestToTensord(unittest.TestCase): + @parameterized.expand(TESTS) def test_array_input(self, test_data, expected_shape): test_data = {"img": test_data} diff --git a/tests/test_torchscript_utils.py b/tests/test_torchscript_utils.py index ec24f388f1..6f8f231829 100644 --- a/tests/test_torchscript_utils.py +++ b/tests/test_torchscript_utils.py @@ -23,11 +23,13 @@ class TestModule(torch.nn.Module): + def forward(self, x): return x + 10 class TestTorchscript(unittest.TestCase): + def test_save_net_with_metadata(self): """Save a network without metadata to a file.""" m = torch.jit.script(TestModule()) diff --git a/tests/test_torchvision.py b/tests/test_torchvision.py index 9cd536aa6f..2931b0c1a8 100644 --- a/tests/test_torchvision.py +++ b/tests/test_torchvision.py @@ -55,6 +55,7 @@ class TestTorchVision(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, input_param, input_data, expected_value): set_determinism(seed=0) diff --git a/tests/test_torchvision_fc_model.py b/tests/test_torchvision_fc_model.py index e913b2b9b1..322cce1161 100644 --- a/tests/test_torchvision_fc_model.py +++ b/tests/test_torchvision_fc_model.py @@ -153,6 +153,7 @@ class TestTorchVisionFCModel(unittest.TestCase): + @parameterized.expand( [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7] + ([TEST_CASE_8] if has_enum else []) @@ -187,6 +188,7 @@ def test_with_pretrained(self, input_param, input_shape, expected_shape, expecte class TestLookup(unittest.TestCase): + def test_get_module(self): net = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32, 64), strides=(2, 2, 2, 2)) self.assertEqual(look_up_named_module("", net), net) diff --git a/tests/test_torchvisiond.py b/tests/test_torchvisiond.py index b2a6bcafc5..ec09692df9 100644 --- a/tests/test_torchvisiond.py +++ b/tests/test_torchvisiond.py @@ -52,6 +52,7 @@ class TestTorchVisiond(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_value(self, input_param, input_data, expected_value): set_determinism(seed=0) diff --git a/tests/test_traceable_transform.py b/tests/test_traceable_transform.py index 42906c84d2..dd139053e3 100644 --- a/tests/test_traceable_transform.py +++ b/tests/test_traceable_transform.py @@ -17,6 +17,7 @@ class _TraceTest(TraceableTransform): + def __call__(self, data): self.push_transform(data) return data @@ -27,6 +28,7 @@ def pop(self, data): class TestTraceable(unittest.TestCase): + def test_default(self): expected_key = "_transforms" a = _TraceTest() diff --git a/tests/test_train_mode.py b/tests/test_train_mode.py index 6136e2f7db..ae99f91363 100644 --- a/tests/test_train_mode.py +++ b/tests/test_train_mode.py @@ -19,6 +19,7 @@ class TestEvalMode(unittest.TestCase): + def test_eval_mode(self): t = torch.rand(1, 1, 4, 4) p = torch.nn.Conv2d(1, 1, 3) diff --git a/tests/test_trainable_bilateral.py b/tests/test_trainable_bilateral.py index 43b628be80..c69eff4071 100644 --- a/tests/test_trainable_bilateral.py +++ b/tests/test_trainable_bilateral.py @@ -273,6 +273,7 @@ @skip_if_no_cpp_extension class BilateralFilterTestCaseCpuPrecise(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_cpu_precise(self, test_case_description, sigmas, input, expected): # Params to determine the implementation to test @@ -371,6 +372,7 @@ def test_cpu_precise_backwards(self, test_case_description, sigmas, input, expec @skip_if_no_cuda @skip_if_no_cpp_extension class BilateralFilterTestCaseCudaPrecise(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_cuda_precise(self, test_case_description, sigmas, input, expected): # Skip this test diff --git a/tests/test_trainable_joint_bilateral.py b/tests/test_trainable_joint_bilateral.py index 8a9c69bda4..4263683ce2 100644 --- a/tests/test_trainable_joint_bilateral.py +++ b/tests/test_trainable_joint_bilateral.py @@ -358,6 +358,7 @@ @skip_if_no_cpp_extension @skip_if_quick class JointBilateralFilterTestCaseCpuPrecise(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_cpu_precise(self, test_case_description, sigmas, input, guide, expected): # Params to determine the implementation to test @@ -481,6 +482,7 @@ def test_cpu_precise_backwards(self, test_case_description, sigmas, input, guide @skip_if_no_cuda @skip_if_no_cpp_extension class JointBilateralFilterTestCaseCudaPrecise(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_cuda_precise(self, test_case_description, sigmas, input, guide, expected): # Skip this test diff --git a/tests/test_transchex.py b/tests/test_transchex.py index 9ad847cdaa..481c20e285 100644 --- a/tests/test_transchex.py +++ b/tests/test_transchex.py @@ -47,6 +47,7 @@ @skip_if_quick class TestTranschex(unittest.TestCase): + @parameterized.expand(TEST_CASE_TRANSCHEX) def test_shape(self, input_param, expected_shape): net = Transchex(**input_param) diff --git a/tests/test_transform.py b/tests/test_transform.py index ea738eaac3..9b05133391 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -20,6 +20,7 @@ class FaultyTransform(mt.Transform): + def __call__(self, _): raise RuntimeError @@ -29,6 +30,7 @@ def faulty_lambda(_): class TestTransform(unittest.TestCase): + @classmethod def setUpClass(cls): super(__class__, cls).setUpClass() diff --git a/tests/test_transformerblock.py b/tests/test_transformerblock.py index 914336668d..5a8dbba83c 100644 --- a/tests/test_transformerblock.py +++ b/tests/test_transformerblock.py @@ -39,6 +39,7 @@ class TestTransformerBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_TRANSFORMERBLOCK) def test_shape(self, input_param, input_shape, expected_shape): net = TransformerBlock(**input_param) diff --git a/tests/test_transpose.py b/tests/test_transpose.py index 0c9ae1c7e3..2f5ccd1235 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -27,6 +27,7 @@ class TestTranspose(unittest.TestCase): + @parameterized.expand(TESTS) def test_transpose(self, im, indices): tr = Transpose(indices) diff --git a/tests/test_transposed.py b/tests/test_transposed.py index ab80520fc9..e7c6ecbe8a 100644 --- a/tests/test_transposed.py +++ b/tests/test_transposed.py @@ -30,6 +30,7 @@ class TestTranspose(unittest.TestCase): + @parameterized.expand(TESTS) def test_transpose(self, im, indices): data = {"i": deepcopy(im), "j": deepcopy(im)} diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py index d1175f40c5..0365503ea2 100644 --- a/tests/test_tversky_loss.py +++ b/tests/test_tversky_loss.py @@ -148,6 +148,7 @@ class TestTverskyLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): result = TverskyLoss(**input_param).forward(**input_data) @@ -164,17 +165,12 @@ def test_ill_shape(self): with self.assertRaisesRegex(ValueError, ""): TverskyLoss(reduction=None)(chn_input, chn_target) - def test_input_warnings(self): + @parameterized.expand([(False, False, False), (False, True, False), (False, False, True)]) + def test_input_warnings(self, include_background, softmax, to_onehot_y): chn_input = torch.ones((1, 1, 3)) chn_target = torch.ones((1, 1, 3)) with self.assertWarns(Warning): - loss = TverskyLoss(include_background=False) - loss.forward(chn_input, chn_target) - with self.assertWarns(Warning): - loss = TverskyLoss(softmax=True) - loss.forward(chn_input, chn_target) - with self.assertWarns(Warning): - loss = TverskyLoss(to_onehot_y=True) + loss = TverskyLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y) loss.forward(chn_input, chn_target) def test_script(self): diff --git a/tests/test_ultrasound_confidence_map_transform.py b/tests/test_ultrasound_confidence_map_transform.py index fbf0c4fe97..63ce7d58e4 100644 --- a/tests/test_ultrasound_confidence_map_transform.py +++ b/tests/test_ultrasound_confidence_map_transform.py @@ -15,6 +15,7 @@ import numpy as np import torch +from parameterized import parameterized from monai.transforms import UltrasoundConfidenceMapTransform from tests.utils import assert_allclose @@ -518,6 +519,7 @@ class TestUltrasoundConfidenceMapTransform(unittest.TestCase): + def setUp(self): self.input_img_np = np.expand_dims(TEST_INPUT, axis=0) # mock image (numpy array) self.input_mask_np = np.expand_dims(TEST_MASK, axis=0) # mock mask (numpy array) @@ -534,162 +536,92 @@ def test_parameters(self): with self.assertRaises(ValueError): UltrasoundConfidenceMapTransform(sink_mode="unknown") - def test_rgb(self): + @parameterized.expand( + [("all", SINK_ALL_OUTPUT), ("mid", SINK_MID_OUTPUT), ("min", SINK_MIN_OUTPUT), ("mask", SINK_MASK_OUTPUT, True)] + ) + def test_ultrasound_confidence_map_transform(self, sink_mode, expected_output, use_mask=False): # RGB image input_img_rgb = np.expand_dims(np.repeat(self.input_img_np, 3, axis=0), axis=0) input_img_rgb_torch = torch.from_numpy(input_img_rgb) - transform = UltrasoundConfidenceMapTransform(sink_mode="all") - result_torch = transform(input_img_rgb_torch) - self.assertIsInstance(result_torch, torch.Tensor) - assert_allclose(result_torch, torch.tensor(SINK_ALL_OUTPUT), rtol=1e-4, atol=1e-4) - result_np = transform(input_img_rgb) - self.assertIsInstance(result_np, np.ndarray) - assert_allclose(result_np, SINK_ALL_OUTPUT, rtol=1e-4, atol=1e-4) + transform = UltrasoundConfidenceMapTransform(sink_mode=sink_mode) - transform = UltrasoundConfidenceMapTransform(sink_mode="mid") - result_torch = transform(input_img_rgb_torch) - self.assertIsInstance(result_torch, torch.Tensor) - assert_allclose(result_torch, torch.tensor(SINK_MID_OUTPUT), rtol=1e-4, atol=1e-4) - result_np = transform(input_img_rgb) - self.assertIsInstance(result_np, np.ndarray) - assert_allclose(result_np, SINK_MID_OUTPUT, rtol=1e-4, atol=1e-4) + if use_mask: + result_torch = transform(input_img_rgb_torch, self.input_mask_torch) + result_np = transform(input_img_rgb, self.input_mask_np) + else: + result_torch = transform(input_img_rgb_torch) + result_np = transform(input_img_rgb) - transform = UltrasoundConfidenceMapTransform(sink_mode="min") - result_torch = transform(input_img_rgb_torch) - self.assertIsInstance(result_torch, torch.Tensor) - assert_allclose(result_torch, torch.tensor(SINK_MIN_OUTPUT), rtol=1e-4, atol=1e-4) - result_np = transform(input_img_rgb) - self.assertIsInstance(result_np, np.ndarray) - assert_allclose(result_np, SINK_MIN_OUTPUT, rtol=1e-4, atol=1e-4) - - transform = UltrasoundConfidenceMapTransform(sink_mode="mask") - result_torch = transform(input_img_rgb_torch, self.input_mask_torch) self.assertIsInstance(result_torch, torch.Tensor) - assert_allclose(result_torch, torch.tensor(SINK_MASK_OUTPUT), rtol=1e-4, atol=1e-4) - result_np = transform(input_img_rgb, self.input_mask_np) + assert_allclose(result_torch, torch.tensor(expected_output), rtol=1e-4, atol=1e-4) self.assertIsInstance(result_np, np.ndarray) - assert_allclose(result_np, SINK_MASK_OUTPUT, rtol=1e-4, atol=1e-4) + assert_allclose(result_np, expected_output, rtol=1e-4, atol=1e-4) - def test_multi_channel_2d(self): - # 2D multi-channel image + @parameterized.expand( + [ + ("all", SINK_ALL_OUTPUT), + ("mid", SINK_MID_OUTPUT), + ("min", SINK_MIN_OUTPUT), + ("mask", SINK_MASK_OUTPUT, True), # Adding a flag for mask cases + ] + ) + def test_multi_channel_2d(self, sink_mode, expected_output, use_mask=False): input_img_rgb = np.expand_dims(np.repeat(self.input_img_np, 17, axis=0), axis=0) input_img_rgb_torch = torch.from_numpy(input_img_rgb) - transform = UltrasoundConfidenceMapTransform(sink_mode="all") - result_torch = transform(input_img_rgb_torch) - self.assertIsInstance(result_torch, torch.Tensor) - assert_allclose(result_torch, torch.tensor(SINK_ALL_OUTPUT), rtol=1e-4, atol=1e-4) - result_np = transform(input_img_rgb) - self.assertIsInstance(result_np, np.ndarray) - assert_allclose(result_np, SINK_ALL_OUTPUT, rtol=1e-4, atol=1e-4) - - transform = UltrasoundConfidenceMapTransform(sink_mode="mid") - result_torch = transform(input_img_rgb_torch) - self.assertIsInstance(result_torch, torch.Tensor) - assert_allclose(result_torch, torch.tensor(SINK_MID_OUTPUT), rtol=1e-4, atol=1e-4) - result_np = transform(input_img_rgb) - self.assertIsInstance(result_np, np.ndarray) - assert_allclose(result_np, SINK_MID_OUTPUT, rtol=1e-4, atol=1e-4) + transform = UltrasoundConfidenceMapTransform(sink_mode=sink_mode) - transform = UltrasoundConfidenceMapTransform(sink_mode="min") - result_torch = transform(input_img_rgb_torch) - self.assertIsInstance(result_torch, torch.Tensor) - assert_allclose(result_torch, torch.tensor(SINK_MIN_OUTPUT), rtol=1e-4, atol=1e-4) - result_np = transform(input_img_rgb) - self.assertIsInstance(result_np, np.ndarray) - assert_allclose(result_np, SINK_MIN_OUTPUT, rtol=1e-4, atol=1e-4) + if use_mask: + result_torch = transform(input_img_rgb_torch, self.input_mask_torch) + result_np = transform(input_img_rgb, self.input_mask_np) + else: + result_torch = transform(input_img_rgb_torch) + result_np = transform(input_img_rgb) - transform = UltrasoundConfidenceMapTransform(sink_mode="mask") - result_torch = transform(input_img_rgb_torch, self.input_mask_torch) self.assertIsInstance(result_torch, torch.Tensor) - assert_allclose(result_torch, torch.tensor(SINK_MASK_OUTPUT), rtol=1e-4, atol=1e-4) - result_np = transform(input_img_rgb, self.input_mask_np) + assert_allclose(result_torch, torch.tensor(expected_output), rtol=1e-4, atol=1e-4) self.assertIsInstance(result_np, np.ndarray) - assert_allclose(result_np, SINK_MASK_OUTPUT, rtol=1e-4, atol=1e-4) + assert_allclose(result_np, expected_output, rtol=1e-4, atol=1e-4) - def test_non_one_first_dim(self): - # Image without first dimension as 1 + @parameterized.expand([("all",), ("mid",), ("min",), ("mask",)]) + def test_non_one_first_dim(self, sink_mode): + transform = UltrasoundConfidenceMapTransform(sink_mode=sink_mode) input_img_rgb = np.repeat(self.input_img_np, 3, axis=0) input_img_rgb_torch = torch.from_numpy(input_img_rgb) - transform = UltrasoundConfidenceMapTransform(sink_mode="all") - with self.assertRaises(ValueError): - transform(input_img_rgb_torch) - with self.assertRaises(ValueError): - transform(input_img_rgb) - - transform = UltrasoundConfidenceMapTransform(sink_mode="mid") - with self.assertRaises(ValueError): - transform(input_img_rgb_torch) - with self.assertRaises(ValueError): - transform(input_img_rgb) - - transform = UltrasoundConfidenceMapTransform(sink_mode="min") - with self.assertRaises(ValueError): - transform(input_img_rgb_torch) - with self.assertRaises(ValueError): - transform(input_img_rgb) - - transform = UltrasoundConfidenceMapTransform(sink_mode="mask") - with self.assertRaises(ValueError): - transform(input_img_rgb_torch, self.input_mask_torch) - with self.assertRaises(ValueError): - transform(input_img_rgb, self.input_mask_np) - - def test_no_first_dim(self): - # Image without first dimension + if sink_mode == "mask": + with self.assertRaises(ValueError): + transform(input_img_rgb_torch, self.input_mask_torch) + with self.assertRaises(ValueError): + transform(input_img_rgb, self.input_mask_np) + else: + with self.assertRaises(ValueError): + transform(input_img_rgb_torch) + with self.assertRaises(ValueError): + transform(input_img_rgb) + + @parameterized.expand([("all",), ("mid",), ("min",), ("mask",)]) + def test_no_first_dim(self, sink_mode): input_img_rgb = self.input_img_np[0] input_img_rgb_torch = torch.from_numpy(input_img_rgb) - transform = UltrasoundConfidenceMapTransform(sink_mode="all") - with self.assertRaises(ValueError): - transform(input_img_rgb_torch) - with self.assertRaises(ValueError): - transform(input_img_rgb) + transform = UltrasoundConfidenceMapTransform(sink_mode=sink_mode) - transform = UltrasoundConfidenceMapTransform(sink_mode="mid") with self.assertRaises(ValueError): transform(input_img_rgb_torch) with self.assertRaises(ValueError): transform(input_img_rgb) - transform = UltrasoundConfidenceMapTransform(sink_mode="min") - with self.assertRaises(ValueError): - transform(input_img_rgb_torch) - with self.assertRaises(ValueError): - transform(input_img_rgb) - - transform = UltrasoundConfidenceMapTransform(sink_mode="mask") - with self.assertRaises(ValueError): - transform(input_img_rgb_torch, self.input_mask_torch) - with self.assertRaises(ValueError): - transform(input_img_rgb, self.input_mask_np) - - def test_sink_all(self): - transform = UltrasoundConfidenceMapTransform(sink_mode="all") - - # This should not raise an exception for torch tensor - result_torch = transform(self.input_img_torch) - self.assertIsInstance(result_torch, torch.Tensor) - - # This should not raise an exception for numpy array - result_np = transform(self.input_img_np) - self.assertIsInstance(result_np, np.ndarray) - - def test_sink_mid(self): - transform = UltrasoundConfidenceMapTransform(sink_mode="mid") - - # This should not raise an exception for torch tensor - result_torch = transform(self.input_img_torch) - self.assertIsInstance(result_torch, torch.Tensor) - - # This should not raise an exception for numpy array - result_np = transform(self.input_img_np) - self.assertIsInstance(result_np, np.ndarray) + if sink_mode == "mask": + with self.assertRaises(ValueError): + transform(input_img_rgb_torch, self.input_mask_torch) + with self.assertRaises(ValueError): + transform(input_img_rgb, self.input_mask_np) - def test_sink_min(self): - transform = UltrasoundConfidenceMapTransform(sink_mode="min") + @parameterized.expand([("all",), ("mid",), ("min",)]) + def test_sink_mode(self, mode): + transform = UltrasoundConfidenceMapTransform(sink_mode=mode) # This should not raise an exception for torch tensor result_torch = transform(self.input_img_torch) diff --git a/tests/test_unet.py b/tests/test_unet.py index 9cb4af3379..1fb98f84b0 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -165,6 +165,7 @@ class TestUNET(unittest.TestCase): + @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): net = UNet(**input_param).to(device) diff --git a/tests/test_unetr.py b/tests/test_unetr.py index 406d30aa12..46018d2bc0 100644 --- a/tests/test_unetr.py +++ b/tests/test_unetr.py @@ -57,6 +57,7 @@ @skip_if_quick class TestUNETR(unittest.TestCase): + @parameterized.expand(TEST_CASE_UNETR) def test_shape(self, input_param, input_shape, expected_shape): net = UNETR(**input_param) diff --git a/tests/test_unetr_block.py b/tests/test_unetr_block.py index 60004be25e..9701557ed6 100644 --- a/tests/test_unetr_block.py +++ b/tests/test_unetr_block.py @@ -102,6 +102,7 @@ class TestResBasicBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_UNETR_BASIC_BLOCK) def test_shape(self, input_param, input_shape, expected_shape): for net in [UnetrBasicBlock(**input_param)]: @@ -124,6 +125,7 @@ def test_script(self): class TestUpBlock(unittest.TestCase): + @parameterized.expand(TEST_UP_BLOCK) def test_shape(self, input_param, input_shape, expected_shape, skip_shape): net = UnetrUpBlock(**input_param) @@ -140,6 +142,7 @@ def test_script(self): class TestPrUpBlock(unittest.TestCase): + @parameterized.expand(TEST_PRUP_BLOCK) def test_shape(self, input_param, input_shape, expected_shape): net = UnetrPrUpBlock(**input_param) diff --git a/tests/test_unified_focal_loss.py b/tests/test_unified_focal_loss.py index 0e7217e2b4..3b868a560e 100644 --- a/tests/test_unified_focal_loss.py +++ b/tests/test_unified_focal_loss.py @@ -38,6 +38,7 @@ class TestAsymmetricUnifiedFocalLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_result(self, input_data, expected_val): loss = AsymmetricUnifiedFocalLoss() diff --git a/tests/test_upsample_block.py b/tests/test_upsample_block.py index a82a31b064..e4890c83bc 100644 --- a/tests/test_upsample_block.py +++ b/tests/test_upsample_block.py @@ -121,6 +121,7 @@ class TestUpsample(unittest.TestCase): + @parameterized.expand(TEST_CASES + TEST_CASES_EQ + TEST_CASES_EQ2) def test_shape(self, input_param, input_shape, expected_shape): net = UpSample(**input_param) diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index 619ae8aee3..6e655289e4 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -29,6 +29,7 @@ class TestPytorchNumpyUnification(unittest.TestCase): + def setUp(self) -> None: set_determinism(0) diff --git a/tests/test_varautoencoder.py b/tests/test_varautoencoder.py index b050983d2c..e957dcfb61 100644 --- a/tests/test_varautoencoder.py +++ b/tests/test_varautoencoder.py @@ -108,6 +108,7 @@ class TestVarAutoEncoder(unittest.TestCase): + @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): net = VarAutoEncoder(**input_param).to(device) diff --git a/tests/test_varnet.py b/tests/test_varnet.py index 3ec6b0f087..a46d58d6a2 100644 --- a/tests/test_varnet.py +++ b/tests/test_varnet.py @@ -32,6 +32,7 @@ class TestVarNet(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, coil_sens_model, refinement_model, num_cascades, input_shape, expected_shape): net = VariationalNetworkModel(coil_sens_model, refinement_model, num_cascades).to(device) diff --git a/tests/test_version.py b/tests/test_version.py index 15f8cd36c6..35ce8d9a2f 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -75,6 +75,7 @@ def _pairwise(iterable): class TestVersionCompare(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_compare_leq(self, a, b, expected=True): """Test version_leq with `a` and `b`""" diff --git a/tests/test_video_datasets.py b/tests/test_video_datasets.py index 790feb51ee..6e344e1caa 100644 --- a/tests/test_video_datasets.py +++ b/tests/test_video_datasets.py @@ -31,6 +31,7 @@ class Base: + class TestVideoDataset(unittest.TestCase): video_source: int | str ds: type[VideoDataset] diff --git a/tests/test_vis_cam.py b/tests/test_vis_cam.py index bb3ff7237a..b641599af2 100644 --- a/tests/test_vis_cam.py +++ b/tests/test_vis_cam.py @@ -67,6 +67,7 @@ class TestClassActivationMap(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, input_data, expected_shape): if input_data["model"] == "densenet2d": diff --git a/tests/test_vis_gradbased.py b/tests/test_vis_gradbased.py index 0fbe328c83..e9db0af240 100644 --- a/tests/test_vis_gradbased.py +++ b/tests/test_vis_gradbased.py @@ -21,6 +21,7 @@ class DenseNetAdjoint(DenseNet121): + def __call__(self, x, adjoint_info): if adjoint_info != 42: raise ValueError @@ -48,6 +49,7 @@ def __call__(self, x, adjoint_info): class TestGradientClassActivationMap(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, vis_type, model, shape): device = "cuda:0" if torch.cuda.is_available() else "cpu" diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py index 4b554de0aa..325b74b3ce 100644 --- a/tests/test_vis_gradcam.py +++ b/tests/test_vis_gradcam.py @@ -24,6 +24,7 @@ class DenseNetAdjoint(DenseNet121): + def __call__(self, x, adjoint_info): if adjoint_info != 42: raise ValueError @@ -149,6 +150,7 @@ def __call__(self, x, adjoint_info): @skip_if_quick class TestGradientClassActivationMap(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, cam_class, input_data, expected_shape): if input_data["model"] == "densenet2d": diff --git a/tests/test_vit.py b/tests/test_vit.py index f911c2d5c9..d27c10f95e 100644 --- a/tests/test_vit.py +++ b/tests/test_vit.py @@ -61,6 +61,7 @@ @skip_if_quick class TestViT(unittest.TestCase): + @parameterized.expand(TEST_CASE_Vit) def test_shape(self, input_param, input_shape, expected_shape): net = ViT(**input_param) @@ -68,75 +69,40 @@ def test_shape(self, input_param, input_shape, expected_shape): result, _ = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) - def test_ill_arg(self): - with self.assertRaises(ValueError): - ViT( - in_channels=1, - img_size=(128, 128, 128), - patch_size=(16, 16, 16), - hidden_size=128, - mlp_dim=3072, - num_layers=12, - num_heads=12, - pos_embed="conv", - classification=False, - dropout_rate=5.0, - ) - - with self.assertRaises(ValueError): - ViT( - in_channels=1, - img_size=(32, 32, 32), - patch_size=(64, 64, 64), - hidden_size=512, - mlp_dim=3072, - num_layers=12, - num_heads=8, - pos_embed="perceptron", - classification=False, - dropout_rate=0.3, - ) - - with self.assertRaises(ValueError): - ViT( - in_channels=1, - img_size=(96, 96, 96), - patch_size=(8, 8, 8), - hidden_size=512, - mlp_dim=3072, - num_layers=12, - num_heads=14, - pos_embed="conv", - classification=False, - dropout_rate=0.3, - ) - - with self.assertRaises(ValueError): - ViT( - in_channels=1, - img_size=(97, 97, 97), - patch_size=(4, 4, 4), - hidden_size=768, - mlp_dim=3072, - num_layers=12, - num_heads=8, - pos_embed="perceptron", - classification=True, - dropout_rate=0.3, - ) - + @parameterized.expand( + [ + (1, (128, 128, 128), (16, 16, 16), 128, 3072, 12, 12, "conv", False, 5.0), + (1, (32, 32, 32), (64, 64, 64), 512, 3072, 12, 8, "perceptron", False, 0.3), + (1, (96, 96, 96), (8, 8, 8), 512, 3072, 12, 14, "conv", False, 0.3), + (1, (97, 97, 97), (4, 4, 4), 768, 3072, 12, 8, "perceptron", True, 0.3), + (4, (96, 96, 96), (16, 16, 16), 768, 3072, 12, 12, "perc", False, 0.3), + ] + ) + def test_ill_arg( + self, + in_channels, + img_size, + patch_size, + hidden_size, + mlp_dim, + num_layers, + num_heads, + pos_embed, + classification, + dropout_rate, + ): with self.assertRaises(ValueError): ViT( - in_channels=4, - img_size=(96, 96, 96), - patch_size=(16, 16, 16), - hidden_size=768, - mlp_dim=3072, - num_layers=12, - num_heads=12, - pos_embed="perc", - classification=False, - dropout_rate=0.3, + in_channels=in_channels, + img_size=img_size, + patch_size=patch_size, + hidden_size=hidden_size, + mlp_dim=mlp_dim, + num_layers=num_layers, + num_heads=num_heads, + pos_embed=pos_embed, + classification=classification, + dropout_rate=dropout_rate, ) @parameterized.expand(TEST_CASE_Vit) diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py index 5e95d3c7fb..c68c583a0e 100644 --- a/tests/test_vitautoenc.py +++ b/tests/test_vitautoenc.py @@ -66,6 +66,7 @@ @skip_if_quick class TestVitAutoenc(unittest.TestCase): + def setUp(self): self.threads = torch.get_num_threads() torch.set_num_threads(4) @@ -81,83 +82,30 @@ def test_shape(self, input_param, input_shape, expected_shape): result, _ = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) - def test_ill_arg(self): - with self.assertRaises(ValueError): - ViTAutoEnc( - in_channels=1, - img_size=(128, 128, 128), - patch_size=(16, 16, 16), - hidden_size=128, - mlp_dim=3072, - num_layers=12, - num_heads=12, - pos_embed="conv", - dropout_rate=5.0, - ) - - with self.assertRaises(ValueError): - ViTAutoEnc( - in_channels=1, - img_size=(32, 32, 32), - patch_size=(64, 64, 64), - hidden_size=512, - mlp_dim=3072, - num_layers=12, - num_heads=8, - pos_embed="perceptron", - dropout_rate=0.3, - ) - - with self.assertRaises(ValueError): - ViTAutoEnc( - in_channels=1, - img_size=(96, 96, 96), - patch_size=(8, 8, 8), - hidden_size=512, - mlp_dim=3072, - num_layers=12, - num_heads=14, - pos_embed="conv", - dropout_rate=0.3, - ) - - with self.assertRaises(ValueError): - ViTAutoEnc( - in_channels=1, - img_size=(97, 97, 97), - patch_size=(4, 4, 4), - hidden_size=768, - mlp_dim=3072, - num_layers=12, - num_heads=8, - pos_embed="perceptron", - dropout_rate=0.3, - ) - - with self.assertRaises(ValueError): - ViTAutoEnc( - in_channels=4, - img_size=(96, 96, 96), - patch_size=(16, 16, 16), - hidden_size=768, - mlp_dim=3072, - num_layers=12, - num_heads=12, - pos_embed="perc", - dropout_rate=0.3, - ) - + @parameterized.expand( + [ + (1, (32, 32, 32), (64, 64, 64), 512, 3072, 12, 8, "perceptron", 0.3), # img_size_too_large_for_patch_size + (1, (96, 96, 96), (8, 8, 8), 512, 3072, 12, 14, "conv", 0.3), # num_heads_out_of_bound + (1, (97, 97, 97), (4, 4, 4), 768, 3072, 12, 8, "perceptron", 0.3), # img_size_not_divisible_by_patch_size + (4, (96, 96, 96), (16, 16, 16), 768, 3072, 12, 12, "perc", 0.3), # invalid_pos_embed + (4, (96, 96, 96), (9, 9, 9), 768, 3072, 12, 12, "perc", 0.3), # patch_size_not_divisible + # Add more test cases as needed + ] + ) + def test_ill_arg( + self, in_channels, img_size, patch_size, hidden_size, mlp_dim, num_layers, num_heads, pos_embed, dropout_rate + ): with self.assertRaises(ValueError): ViTAutoEnc( - in_channels=4, - img_size=(96, 96, 96), - patch_size=(9, 9, 9), - hidden_size=768, - mlp_dim=3072, - num_layers=12, - num_heads=12, - pos_embed="perc", - dropout_rate=0.3, + in_channels=in_channels, + img_size=img_size, + patch_size=patch_size, + hidden_size=hidden_size, + mlp_dim=mlp_dim, + num_layers=num_layers, + num_heads=num_heads, + pos_embed=pos_embed, + dropout_rate=dropout_rate, ) diff --git a/tests/test_vnet.py b/tests/test_vnet.py index 633893ce51..0ebf060434 100644 --- a/tests/test_vnet.py +++ b/tests/test_vnet.py @@ -55,6 +55,7 @@ class TestVNet(unittest.TestCase): + @parameterized.expand( [ TEST_CASE_VNET_2D_1, diff --git a/tests/test_vote_ensemble.py b/tests/test_vote_ensemble.py index 32ff120c5d..4abdd0b050 100644 --- a/tests/test_vote_ensemble.py +++ b/tests/test_vote_ensemble.py @@ -71,6 +71,7 @@ class TestVoteEnsemble(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, input_param, img, expected_value): result = VoteEnsemble(**input_param)(img) diff --git a/tests/test_vote_ensembled.py b/tests/test_vote_ensembled.py index 17f9d54835..957133d7fc 100644 --- a/tests/test_vote_ensembled.py +++ b/tests/test_vote_ensembled.py @@ -86,6 +86,7 @@ class TestVoteEnsembled(unittest.TestCase): + @parameterized.expand(TESTS) def test_value(self, input_param, img, expected_value): result = VoteEnsembled(**input_param)(img) diff --git a/tests/test_voxelmorph.py b/tests/test_voxelmorph.py index 53ef2fc18f..ef420ef20c 100644 --- a/tests/test_voxelmorph.py +++ b/tests/test_voxelmorph.py @@ -245,6 +245,7 @@ class TestVOXELMORPH(unittest.TestCase): + @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): net = VoxelMorphUNet(**input_param).to(device) diff --git a/tests/test_warp.py b/tests/test_warp.py index e614973f90..bac595224f 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -106,6 +106,7 @@ @skip_if_quick class TestWarp(unittest.TestCase): + def setUp(self): config = testing_data_config("images", "Prostate_T2W_AX_1") download_url_or_skip_test( diff --git a/tests/test_watershed.py b/tests/test_watershed.py index a5a232ba3c..3f7a29bfe7 100644 --- a/tests/test_watershed.py +++ b/tests/test_watershed.py @@ -43,6 +43,7 @@ @unittest.skipUnless(has_skimage, "Requires scikit-image library.") @unittest.skipUnless(has_scipy, "Requires scipy library.") class TestWatershed(unittest.TestCase): + @parameterized.expand(TESTS) def test_output(self, args, image, hover_map, expected_shape): mask = GenerateWatershedMask()(image) diff --git a/tests/test_watershedd.py b/tests/test_watershedd.py index c12f5ad140..fc44996be4 100644 --- a/tests/test_watershedd.py +++ b/tests/test_watershedd.py @@ -48,6 +48,7 @@ @unittest.skipUnless(has_skimage, "Requires scikit-image library.") @unittest.skipUnless(has_scipy, "Requires scipy library.") class TestWatershedd(unittest.TestCase): + @parameterized.expand(TESTS) def test_output(self, args, image, hover_map, expected_shape): data = {"output": image, "hover_map": hover_map} diff --git a/tests/test_weight_init.py b/tests/test_weight_init.py index 376faacc56..a682ec6cc9 100644 --- a/tests/test_weight_init.py +++ b/tests/test_weight_init.py @@ -32,6 +32,7 @@ class TestWeightInit(unittest.TestCase): + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_shape): im = torch.rand(input_shape) diff --git a/tests/test_weighted_random_sampler_dist.py b/tests/test_weighted_random_sampler_dist.py index d38bab54f0..8e37482da6 100644 --- a/tests/test_weighted_random_sampler_dist.py +++ b/tests/test_weighted_random_sampler_dist.py @@ -24,6 +24,7 @@ @skip_if_windows @skip_if_darwin class DistributedWeightedRandomSamplerTest(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) def test_sampling(self): data = [1, 2, 3, 4, 5] diff --git a/tests/test_with_allow_missing_keys.py b/tests/test_with_allow_missing_keys.py index ec55654f07..427f64c705 100644 --- a/tests/test_with_allow_missing_keys.py +++ b/tests/test_with_allow_missing_keys.py @@ -19,6 +19,7 @@ class TestWithAllowMissingKeysMode(unittest.TestCase): + def setUp(self): self.data = {"image": np.arange(16, dtype=float).reshape(1, 4, 4)} diff --git a/tests/test_write_metrics_reports.py b/tests/test_write_metrics_reports.py index 4f61e43fe1..1013f15d85 100644 --- a/tests/test_write_metrics_reports.py +++ b/tests/test_write_metrics_reports.py @@ -23,6 +23,7 @@ class TestWriteMetricsReports(unittest.TestCase): + def test_content(self): with tempfile.TemporaryDirectory() as tempdir: write_metrics_reports( diff --git a/tests/test_wsi_sliding_window_splitter.py b/tests/test_wsi_sliding_window_splitter.py index ac1a136489..c510ece272 100644 --- a/tests/test_wsi_sliding_window_splitter.py +++ b/tests/test_wsi_sliding_window_splitter.py @@ -102,6 +102,7 @@ # Filtering functions test cases def gen_location_filter(locations): + def my_filter(patch, loc): if loc in locations: return False @@ -198,6 +199,7 @@ def setUpModule(): class WSISlidingWindowSplitterTests(unittest.TestCase): + @parameterized.expand( [ TEST_CASE_WSI_0_BASE, diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py index aae2b0dbaf..99a86c5ac8 100644 --- a/tests/test_wsireader.py +++ b/tests/test_wsireader.py @@ -402,6 +402,7 @@ def setUpModule(): class WSIReaderTests: + class Tests(unittest.TestCase): backend = None @@ -640,6 +641,7 @@ def test_errors(self, file_path, reader_kwargs, patch_info, exception): @skipUnless(has_cucim, "Requires cucim") class TestCuCIM(WSIReaderTests.Tests): + @classmethod def setUpClass(cls): cls.backend = "cucim" @@ -647,6 +649,7 @@ def setUpClass(cls): @skipUnless(has_osl, "Requires openslide") class TestOpenSlide(WSIReaderTests.Tests): + @classmethod def setUpClass(cls): cls.backend = "openslide" @@ -654,6 +657,7 @@ def setUpClass(cls): @skipUnless(has_tiff, "Requires tifffile") class TestTiffFile(WSIReaderTests.Tests): + @classmethod def setUpClass(cls): cls.backend = "tifffile" diff --git a/tests/test_zarr_avg_merger.py b/tests/test_zarr_avg_merger.py index c4c7fad5da..de7fad48da 100644 --- a/tests/test_zarr_avg_merger.py +++ b/tests/test_zarr_avg_merger.py @@ -256,6 +256,7 @@ @unittest.skipUnless(has_zarr and has_numcodecs, "Requires zarr (and numcodecs) packages.)") class ZarrAvgMergerTests(unittest.TestCase): + @parameterized.expand( [ TEST_CASE_0_DEFAULT_DTYPE, diff --git a/tests/test_zipdataset.py b/tests/test_zipdataset.py index de8a8e80d6..2939ff3f49 100644 --- a/tests/test_zipdataset.py +++ b/tests/test_zipdataset.py @@ -20,6 +20,7 @@ class Dataset_(torch.utils.data.Dataset): + def __init__(self, length, index_only=True): self.len = length self.index_only = index_only @@ -48,6 +49,7 @@ def __getitem__(self, index): class TestZipDataset(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_value(self, datasets, transform, expected_output, expected_length): test_dataset = ZipDataset(datasets=datasets, transform=transform) diff --git a/tests/test_zoom.py b/tests/test_zoom.py index e1ea3c25a3..2db2df4486 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -43,6 +43,7 @@ class TestZoom(NumpyImageTestCase2D): + @parameterized.expand(VALID_CASES) def test_pending_ops(self, zoom, mode, align_corners=False, keep_size=False): im = MetaTensor(self.imt[0], meta={"a": "b", "affine": DEFAULT_TEST_AFFINE}) diff --git a/tests/test_zoom_affine.py b/tests/test_zoom_affine.py index dc39a4f1c2..ae8e688d96 100644 --- a/tests/test_zoom_affine.py +++ b/tests/test_zoom_affine.py @@ -64,6 +64,7 @@ class TestZoomAffine(unittest.TestCase): + @parameterized.expand(VALID_CASES) def test_correct(self, affine, scale, expected): output = zoom_affine(affine, scale, diagonal=False) diff --git a/tests/test_zoomd.py b/tests/test_zoomd.py index 1dcbf98572..ad91f398ff 100644 --- a/tests/test_zoomd.py +++ b/tests/test_zoomd.py @@ -34,6 +34,7 @@ class TestZoomd(NumpyImageTestCase2D): + @parameterized.expand(VALID_CASES) def test_correct_results(self, zoom, mode, keep_size, align_corners=None): key = "img" diff --git a/tests/testing_data/fl_infer_properties.json b/tests/testing_data/fl_infer_properties.json new file mode 100644 index 0000000000..72e97cd2c6 --- /dev/null +++ b/tests/testing_data/fl_infer_properties.json @@ -0,0 +1,67 @@ +{ + "bundle_root": { + "description": "root path of the bundle.", + "required": true, + "id": "bundle_root" + }, + "device": { + "description": "target device to execute the bundle workflow.", + "required": true, + "id": "device" + }, + "dataset_dir": { + "description": "directory path of the dataset.", + "required": true, + "id": "dataset_dir" + }, + "dataset": { + "description": "PyTorch dataset object for the inference / evaluation logic.", + "required": true, + "id": "dataset" + }, + "evaluator": { + "description": "inference / evaluation workflow engine.", + "required": true, + "id": "evaluator" + }, + "network_def": { + "description": "network module for the inference.", + "required": true, + "id": "network_def" + }, + "inferer": { + "description": "MONAI Inferer object to execute the model computation in inference.", + "required": true, + "id": "inferer" + }, + "dataset_data": { + "description": "data source for the inference / evaluation dataset.", + "required": false, + "id": "dataset::data", + "refer_id": null + }, + "handlers": { + "description": "event-handlers for the inference / evaluation logic.", + "required": false, + "id": "handlers", + "refer_id": "evaluator::val_handlers" + }, + "preprocessing": { + "description": "preprocessing for the input data.", + "required": false, + "id": "preprocessing", + "refer_id": "dataset::transform" + }, + "postprocessing": { + "description": "postprocessing for the model output data.", + "required": false, + "id": "postprocessing", + "refer_id": "evaluator::postprocessing" + }, + "key_metric": { + "description": "the key metric during evaluation.", + "required": false, + "id": "key_metric", + "refer_id": "evaluator::key_val_metric" + } +} diff --git a/tests/testing_data/integration_answers.py b/tests/testing_data/integration_answers.py index c0dd973418..e02b9ae995 100644 --- a/tests/testing_data/integration_answers.py +++ b/tests/testing_data/integration_answers.py @@ -600,6 +600,62 @@ ], } }, + { # test answers for 24.03 + "integration_segmentation_3d": { + "losses": [ + 0.5442982316017151, + 0.4741817444562912, + 0.4535954713821411, + 0.44163046181201937, + 0.4307525992393494, + 0.428487154841423, + ], + "best_metric": 0.9314384460449219, + "infer_metric": 0.9315622448921204, + "output_sums": [ + 0.14268704426414708, + 0.1528672845845743, + 0.1521782248125706, + 0.14028769128068194, + 0.1889830671664784, + 0.16999075690664475, + 0.14736282992708227, + 0.16877952654821815, + 0.15779597155181269, + 0.17987829927082263, + 0.16320253928314676, + 0.16854299322173155, + 0.14497470986956967, + 0.11437140546369519, + 0.1624117412960871, + 0.20156009294443875, + 0.1764654154256958, + 0.0982348259217418, + 0.1942436068604293, + 0.20359421536407518, + 0.19661953116976483, + 0.2088326101468625, + 0.16273043545239807, + 0.1326107887439663, + 0.1489245275752285, + 0.143107476635514, + 0.23189027677929547, + 0.1613818424566088, + 0.14889532196775188, + 0.10332622984492143, + 0.11940054688302351, + 0.13040496302762658, + 0.11472123087193181, + 0.15307044007394474, + 0.16371989575844717, + 0.1942898223272055, + 0.2230120930471398, + 0.1814679187634795, + 0.19069496508164732, + 0.07537197031940022, + ], + } + }, ] diff --git a/tests/utils.py b/tests/utils.py index ee800598bb..ea73a3ed81 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -677,6 +677,7 @@ def setUp(self): class TorchImageTestCase2D(NumpyImageTestCase2D): + def setUp(self): NumpyImageTestCase2D.setUp(self) self.imt = torch.tensor(self.imt) @@ -707,6 +708,7 @@ def setUp(self): class TorchImageTestCase3D(NumpyImageTestCase3D): + def setUp(self): NumpyImageTestCase3D.setUp(self) self.imt = torch.tensor(self.imt) From 1a57b551474cd1740065862113331a9dddce84ca Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 10 May 2024 13:16:32 +0100 Subject: [PATCH 15/32] 7227 refactor transformer and diffusion model unet (#7715) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Part of #7227 . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu Signed-off-by: kaibo Signed-off-by: heyufan1995 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: binliu Signed-off-by: dependabot[bot] Signed-off-by: axel.vlaminck Signed-off-by: monai-bot Signed-off-by: Ibrahim Hadzic Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> Signed-off-by: Timothy Baker Signed-off-by: Mathijs de Boer Signed-off-by: Fabian Klopfer Signed-off-by: Lucas Robinet Signed-off-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Signed-off-by: chaoliu Signed-off-by: cxlcl Signed-off-by: chaoliu Signed-off-by: Suraj Pai Signed-off-by: Juan Pablo de la Cruz Gutiérrez Signed-off-by: elitap Signed-off-by: Felix Schnabel Signed-off-by: YanxuanLiu Signed-off-by: ytl0623 Signed-off-by: Dženan Zukić Signed-off-by: Ishan Dutta Signed-off-by: John Zielke Signed-off-by: Mingxin Zheng Signed-off-by: Vladimir Chernyi <57420464+scalyvladimir@users.noreply.github.com> Signed-off-by: Yiheng Wang Signed-off-by: Szabolcs Botond Lorincz Molnar Signed-off-by: Lucas Robinet Signed-off-by: Mingxin Signed-off-by: Han Wang Signed-off-by: Konstantin Sukharev Signed-off-by: Ben Murray Signed-off-by: Matthew Vine <32849887+MattTheCuber@users.noreply.github.com> Signed-off-by: Mark Graham Signed-off-by: Peter Kaplinsky Signed-off-by: Simon Jensen <61684806+simojens@users.noreply.github.com> Signed-off-by: NabJa Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Kaibo Tang Co-authored-by: Yufan He <59374597+heyufan1995@users.noreply.github.com> Co-authored-by: binliunls <107988372+binliunls@users.noreply.github.com> Co-authored-by: Ben Murray Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: axel.vlaminck Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Co-authored-by: monai-bot <64792179+monai-bot@users.noreply.github.com> Co-authored-by: Ibrahim Hadzic Co-authored-by: Dr. Behrooz Hashemian <3968947+drbeh@users.noreply.github.com> Co-authored-by: Timothy J. Baker <62781117+tim-the-baker@users.noreply.github.com> Co-authored-by: Mathijs de Boer <8137653+MathijsdeBoer@users.noreply.github.com> Co-authored-by: Mathijs de Boer Co-authored-by: Fabian Klopfer Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Co-authored-by: Lucas Robinet Co-authored-by: cxlcl Co-authored-by: Suraj Pai Co-authored-by: Juampa <1523654+juampatronics@users.noreply.github.com> Co-authored-by: elitap Co-authored-by: Felix Schnabel Co-authored-by: YanxuanLiu <104543031+YanxuanLiu@users.noreply.github.com> Co-authored-by: ytl0623 Co-authored-by: Dženan Zukić Co-authored-by: Ishan Dutta Co-authored-by: johnzielke Co-authored-by: Vladimir Chernyi <57420464+scalyvladimir@users.noreply.github.com> Co-authored-by: Lőrincz-Molnár Szabolcs-Botond Co-authored-by: Nic Ma Co-authored-by: Lucas Robinet Co-authored-by: Han Wang Co-authored-by: Konstantin Sukharev <50718389+k-sukharev@users.noreply.github.com> Co-authored-by: Matthew Vine <32849887+MattTheCuber@users.noreply.github.com> Co-authored-by: Pkaps25 <43655728+Pkaps25@users.noreply.github.com> Co-authored-by: Peter Kaplinsky Co-authored-by: Simon Jensen <61684806+simojens@users.noreply.github.com> Co-authored-by: NabJa <32510324+NabJa@users.noreply.github.com> --- .github/workflows/pythonapp-min.yml | 2 + .github/workflows/pythonapp.yml | 6 +- .pre-commit-config.yaml | 2 +- Dockerfile | 9 +- docs/source/networks.rst | 1 + monai/apps/utils.py | 7 +- monai/bundle/workflows.py | 10 +- monai/fl/client/monai_algo.py | 14 +- monai/fl/utils/constants.py | 1 + monai/losses/ds_loss.py | 2 +- monai/networks/blocks/__init__.py | 2 + monai/networks/blocks/crossattention.py | 166 +++++ monai/networks/blocks/selfattention.py | 53 +- monai/networks/blocks/spade_norm.py | 2 +- monai/networks/blocks/spatialattention.py | 82 +++ monai/networks/blocks/transformerblock.py | 28 +- monai/networks/nets/attentionunet.py | 12 +- monai/networks/nets/autoencoderkl.py | 66 +- monai/networks/nets/controlnet.py | 9 - monai/networks/nets/diffusion_model_unet.py | 578 ++++++------------ monai/networks/nets/resnet.py | 1 - monai/networks/nets/spade_autoencoderkl.py | 8 +- .../nets/spade_diffusion_model_unet.py | 123 ++-- monai/networks/nets/transformer.py | 267 ++------ monai/networks/utils.py | 1 + monai/utils/misc.py | 2 +- requirements-dev.txt | 4 +- tests/test_attentionunet.py | 20 + tests/test_autoencoderkl.py | 37 +- tests/test_bundle_ckpt_export.py | 6 +- tests/test_bundle_get_data.py | 15 +- tests/test_bundle_trt_export.py | 12 +- tests/test_bundle_workflow.py | 6 +- tests/test_clip_intensity_percentilesd.py | 2 +- tests/test_component_store.py | 8 +- tests/test_compute_ho_ver_maps.py | 4 +- tests/test_compute_ho_ver_maps_d.py | 4 +- tests/test_compute_regression_metrics.py | 10 +- tests/test_concat_itemsd.py | 8 +- tests/test_config_parser.py | 2 +- tests/test_controlnet.py | 5 + tests/test_controlnet_inferers.py | 19 + tests/test_crossattention.py | 131 ++++ tests/test_cucim_dict_transform.py | 16 +- tests/test_cucim_transform.py | 16 +- tests/test_detect_envelope.py | 2 +- tests/test_diffusion_inferer.py | 10 + tests/test_diffusion_model_unet.py | 50 ++ tests/test_ensure_typed.py | 32 +- tests/test_flipd.py | 2 +- tests/test_freeze_layers.py | 8 +- tests/test_generalized_dice_loss.py | 4 +- tests/test_get_package_version.py | 6 +- tests/test_grid_patch.py | 6 +- tests/test_handler_stats.py | 16 +- tests/test_integration_bundle_run.py | 6 +- tests/test_inverse_collation.py | 2 +- tests/test_invertd.py | 2 +- tests/test_latent_diffusion_inferer.py | 12 + tests/test_load_imaged.py | 2 +- tests/test_load_spacing_orientation.py | 4 +- tests/test_look_up_option.py | 2 +- tests/test_matshow3d.py | 2 +- tests/test_median_filter.py | 2 +- tests/test_mednistdataset.py | 2 +- tests/test_meta_affine.py | 4 +- tests/test_meta_tensor.py | 4 +- tests/test_mmar_download.py | 2 +- tests/test_persistentdataset.py | 2 +- tests/test_rand_affined.py | 2 +- tests/test_rand_bias_field.py | 2 +- tests/test_rand_weighted_cropd.py | 2 +- tests/test_recon_net_utils.py | 2 +- tests/test_reg_loss_integration.py | 2 +- tests/test_resnet.py | 10 +- tests/test_selfattention.py | 55 ++ tests/test_sobel_gradient.py | 4 +- tests/test_sobel_gradientd.py | 4 +- tests/test_spade_diffusion_model_unet.py | 16 + tests/test_spatialattention.py | 55 ++ tests/test_threadcontainer.py | 2 +- tests/test_to_cupy.py | 16 +- tests/test_to_numpy.py | 12 +- tests/test_torchvision_fc_model.py | 4 +- tests/test_traceable_transform.py | 4 +- tests/test_transformer.py | 36 ++ tests/test_transformerblock.py | 29 +- tests/test_vqvaetransformer_inferer.py | 11 + tests/test_warp.py | 2 +- tests/testing_data/data_config.json | 15 + 90 files changed, 1327 insertions(+), 921 deletions(-) create mode 100644 monai/networks/blocks/crossattention.py create mode 100644 monai/networks/blocks/spatialattention.py create mode 100644 tests/test_crossattention.py create mode 100644 tests/test_spatialattention.py diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml index bbe7579774..dffae10558 100644 --- a/.github/workflows/pythonapp-min.yml +++ b/.github/workflows/pythonapp-min.yml @@ -9,6 +9,8 @@ on: - main - releasing/* pull_request: + head_ref-ignore: + - dev concurrency: # automatically cancel the previously triggered workflows when there's a newer version diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index b7f2cfb9db..b8b73907d4 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -9,6 +9,8 @@ on: - main - releasing/* pull_request: + head_ref-ignore: + - dev concurrency: # automatically cancel the previously triggered workflows when there's a newer version @@ -68,10 +70,10 @@ jobs: maximum-size: 16GB disk-root: "D:" - uses: actions/checkout@v4 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v5 with: - python-version: '3.8' + python-version: '3.9' - name: Prepare pip wheel run: | which python diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b71a2bac43..b9debaf08f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -69,7 +69,7 @@ repos: )$ - repo: https://github.com/hadialqattan/pycln - rev: v2.1.3 + rev: v2.4.0 hooks: - id: pycln args: [--config=pyproject.toml] diff --git a/Dockerfile b/Dockerfile index fc97227351..10931222dd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,11 +18,10 @@ LABEL maintainer="monai.contact@gmail.com" # TODO: remark for issue [revise the dockerfile](https://github.com/zarr-developers/numcodecs/issues/431) RUN if [[ $(uname -m) =~ "aarch64" ]]; then \ - cd /opt && \ - git clone --branch v0.12.1 --recursive https://github.com/zarr-developers/numcodecs && \ - pip wheel numcodecs && \ - rm -r /opt/*.whl && \ - rm -rf /opt/numcodecs; \ + export CFLAGS="-O3" && \ + export DISABLE_NUMCODECS_SSE2=true && \ + export DISABLE_NUMCODECS_AVX2=true && \ + pip install numcodecs; \ fi WORKDIR /opt/monai diff --git a/docs/source/networks.rst b/docs/source/networks.rst index c51f5c88b1..8321fed1a4 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -426,6 +426,7 @@ Layers .. autoclass:: monai.networks.layers.vector_quantizer.VectorQuantizer :members: +======= `ConjugateGradient` ~~~~~~~~~~~~~~~~~~~ .. autoclass:: ConjugateGradient diff --git a/monai/apps/utils.py b/monai/apps/utils.py index db541923b5..0c998146a3 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -135,7 +135,12 @@ def check_hash(filepath: PathLike, val: str | None = None, hash_type: str = "md5 logger.info(f"Expected {hash_type} is None, skip {hash_type} check for file {filepath}.") return True actual_hash_func = look_up_option(hash_type.lower(), SUPPORTED_HASH_TYPES) - actual_hash = actual_hash_func() + + if sys.version_info >= (3, 9): + actual_hash = actual_hash_func(usedforsecurity=False) # allows checks on FIPS enabled machines + else: + actual_hash = actual_hash_func() + try: with open(filepath, "rb") as f: for chunk in iter(lambda: f.read(1024 * 1024), b""): diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 471088994b..b42852cb0f 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -239,6 +239,7 @@ class ConfigWorkflow(BundleWorkflow): logging_file: config file for `logging` module in the program. for more details: https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. If None, default to "configs/logging.conf", which is commonly used for bundles in MONAI model zoo. + If False, the logging logic for the bundle will not be modified. init_id: ID name of the expected config expression to initialize before running, default to "initialize". allow a config to have no `initialize` logic and the ID. run_id: ID name of the expected config expression to run, default to "run". @@ -278,7 +279,7 @@ def __init__( self, config_file: str | Sequence[str], meta_file: str | Sequence[str] | None = None, - logging_file: str | None = None, + logging_file: str | bool | None = None, init_id: str = "initialize", run_id: str = "run", final_id: str = "finalize", @@ -307,7 +308,10 @@ def __init__( super().__init__(workflow_type=workflow_type, meta_file=meta_file, properties_path=properties_path) self.config_root_path = config_root_path logging_file = str(self.config_root_path / "logging.conf") if logging_file is None else logging_file - if logging_file is not None: + + if logging_file is False: + logger.warn(f"Logging file is set to {logging_file}, skipping logging.") + else: if not os.path.isfile(logging_file): if logging_file == str(self.config_root_path / "logging.conf"): logger.warn(f"Default logging file in {logging_file} does not exist, skipping logging.") @@ -315,7 +319,7 @@ def __init__( raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.") else: logger.info(f"Setting logging properties based on config: {logging_file}.") - fileConfig(logging_file, disable_existing_loggers=False) + fileConfig(str(logging_file), disable_existing_loggers=False) self.parser = ConfigParser() self.parser.read_config(f=config_file) diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index 9acf131bd9..a3ac58c221 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -134,12 +134,14 @@ def initialize(self, extra=None): Args: extra: Dict with additional information that should be provided by FL system, - i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`. + i.e., `ExtraItems.CLIENT_NAME`, `ExtraItems.APP_ROOT` and `ExtraItems.LOGGING_FILE`. + You can diable the logging logic in the monai bundle by setting {ExtraItems.LOGGING_FILE} to False. """ if extra is None: extra = {} self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname") + logging_file = extra.get(ExtraItems.LOGGING_FILE, None) self.logger.info(f"Initializing {self.client_name} ...") # FL platform needs to provide filepath to configuration files @@ -149,7 +151,7 @@ def initialize(self, extra=None): if self.workflow is None: config_train_files = self._add_config_files(self.config_train_filename) self.workflow = ConfigWorkflow( - config_file=config_train_files, meta_file=None, logging_file=None, workflow_type="train" + config_file=config_train_files, meta_file=None, logging_file=logging_file, workflow_type="train" ) self.workflow.initialize() self.workflow.bundle_root = self.bundle_root @@ -412,13 +414,15 @@ def initialize(self, extra=None): Args: extra: Dict with additional information that should be provided by FL system, - i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`. + i.e., `ExtraItems.CLIENT_NAME`, `ExtraItems.APP_ROOT` and `ExtraItems.LOGGING_FILE`. + You can diable the logging logic in the monai bundle by setting {ExtraItems.LOGGING_FILE} to False. """ self._set_cuda_device() if extra is None: extra = {} self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname") + logging_file = extra.get(ExtraItems.LOGGING_FILE, None) timestamp = time.strftime("%Y%m%d_%H%M%S") self.logger.info(f"Initializing {self.client_name} ...") # FL platform needs to provide filepath to configuration files @@ -434,7 +438,7 @@ def initialize(self, extra=None): self.train_workflow = ConfigWorkflow( config_file=config_train_files, meta_file=None, - logging_file=None, + logging_file=logging_file, workflow_type="train", **self.train_kwargs, ) @@ -459,7 +463,7 @@ def initialize(self, extra=None): self.eval_workflow = ConfigWorkflow( config_file=config_eval_files, meta_file=None, - logging_file=None, + logging_file=logging_file, workflow_type=self.eval_workflow_name, **self.eval_kwargs, ) diff --git a/monai/fl/utils/constants.py b/monai/fl/utils/constants.py index eda1a6b4f9..18beceeaee 100644 --- a/monai/fl/utils/constants.py +++ b/monai/fl/utils/constants.py @@ -30,6 +30,7 @@ class ExtraItems(StrEnum): CLIENT_NAME = "fl_client_name" APP_ROOT = "fl_app_root" STATS_SENDER = "fl_stats_sender" + LOGGING_FILE = "logging_file" class FlPhase(StrEnum): diff --git a/monai/losses/ds_loss.py b/monai/losses/ds_loss.py index 57fcff6b87..aacc16874d 100644 --- a/monai/losses/ds_loss.py +++ b/monai/losses/ds_loss.py @@ -33,7 +33,7 @@ def __init__(self, loss: _Loss, weight_mode: str = "exp", weights: list[float] | weight_mode: {``"same"``, ``"exp"``, ``"two"``} Specifies the weights calculation for each image level. Defaults to ``"exp"``. - ``"same"``: all weights are equal to 1. - - ``"exp"``: exponentially decreasing weights by a power of 2: 0, 0.5, 0.25, 0.125, etc . + - ``"exp"``: exponentially decreasing weights by a power of 2: 1, 0.5, 0.25, 0.125, etc . - ``"two"``: equal smaller weights for lower levels: 1, 0.5, 0.5, 0.5, 0.5, etc weights: a list of weights to apply to each deeply supervised sub-loss, if provided, this will be used regardless of the weight_mode diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index afb6664bd9..47abc4a1c4 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -17,6 +17,7 @@ from .backbone_fpn_utils import BackboneWithFPN from .convolutions import Convolution, ResidualUnit from .crf import CRF +from .crossattention import CrossAttentionBlock from .denseblock import ConvDenseBlock, DenseBlock from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock from .downsample import MaxAvgPool @@ -31,6 +32,7 @@ from .segresnet_block import ResBlock from .selfattention import SABlock from .spade_norm import SPADE +from .spatialattention import SpatialAttentionBlock from .squeeze_and_excitation import ( ChannelSELayer, ResidualSELayer, diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py new file mode 100644 index 0000000000..dc1d5d388e --- /dev/null +++ b/monai/networks/blocks/crossattention.py @@ -0,0 +1,166 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from monai.networks.layers.utils import get_rel_pos_embedding_layer +from monai.utils import optional_import + +Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") + + +class CrossAttentionBlock(nn.Module): + """ + A cross-attention block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + One can setup relative positional embedding as described in + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + dropout_rate: float = 0.0, + hidden_input_size: int | None = None, + context_input_size: int | None = None, + dim_head: int | None = None, + qkv_bias: bool = False, + save_attn: bool = False, + causal: bool = False, + sequence_length: int | None = None, + rel_pos_embedding: Optional[str] = None, + input_size: Optional[Tuple] = None, + attention_dtype: Optional[torch.dtype] = None, + ) -> None: + """ + Args: + hidden_size (int): dimension of hidden layer. + num_heads (int): number of attention heads. + dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. + hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size. + context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size. + dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads. + qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. + save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. + causal: whether to use causal attention. + sequence_length: if causal is True, it is necessary to specify the sequence length. + rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. + For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. + input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative + positional parameter size. + attention_dtype: cast attention operations to this dtype. + """ + + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + + if dim_head: + inner_size = num_heads * dim_head + self.head_dim = dim_head + else: + if hidden_size % num_heads != 0: + raise ValueError("hidden size should be divisible by num_heads.") + inner_size = hidden_size + self.head_dim = hidden_size // num_heads + + if causal and sequence_length is None: + raise ValueError("sequence_length is necessary for causal attention.") + + self.num_heads = num_heads + self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size + self.context_input_size = context_input_size if context_input_size else hidden_size + self.out_proj = nn.Linear(inner_size, self.hidden_input_size) + # key, query, value projections + self.to_q = nn.Linear(self.hidden_input_size, inner_size, bias=qkv_bias) + self.to_k = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias) + self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias) + self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads) + + self.out_rearrange = Rearrange("b h l d -> b l (h d)") + self.drop_output = nn.Dropout(dropout_rate) + self.drop_weights = nn.Dropout(dropout_rate) + + self.scale = self.head_dim**-0.5 + self.save_attn = save_attn + self.attention_dtype = attention_dtype + + self.causal = causal + self.sequence_length = sequence_length + + if causal and sequence_length is not None: + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "causal_mask", + torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), + ) + self.causal_mask: torch.Tensor + + self.att_mat = torch.Tensor() + self.rel_positional_embedding = ( + get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads) + if rel_pos_embedding is not None + else None + ) + self.input_size = input_size + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None): + """ + Args: + x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C + context (torch.Tensor, optional): context tensor. B x (s_dim_1 * ... * s_dim_n) x C + + Return: + torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C + """ + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) + + q = self.to_q(x) + kv = context if context is not None else x + _, kv_t, _ = kv.size() + k = self.to_k(kv) + v = self.to_v(kv) + + if self.attention_dtype is not None: + q = q.to(self.attention_dtype) + k = k.to(self.attention_dtype) + + q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) + k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) + v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) + att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale + + # apply relative positional embedding if defined + att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + + if self.causal: + att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) + + att_mat = att_mat.softmax(dim=-1) + + if self.save_attn: + # no gradients and new tensor; + # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html + self.att_mat = att_mat.detach() + + att_mat = self.drop_weights(att_mat) + x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + x = self.out_rearrange(x) + x = self.out_proj(x) + x = self.drop_output(x) + return x diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 3bef24b4e8..370ad38595 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -34,22 +34,32 @@ def __init__( hidden_size: int, num_heads: int, dropout_rate: float = 0.0, + hidden_input_size: int | None = None, + dim_head: int | None = None, qkv_bias: bool = False, save_attn: bool = False, + causal: bool = False, + sequence_length: int | None = None, rel_pos_embedding: Optional[str] = None, input_size: Optional[Tuple] = None, + attention_dtype: Optional[torch.dtype] = None, ) -> None: """ Args: hidden_size (int): dimension of hidden layer. num_heads (int): number of attention heads. dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. + hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size. + dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads. qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. + save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. + causal: whether to use causal attention (see https://arxiv.org/abs/1706.03762). + sequence_length: if causal is True, it is necessary to specify the sequence length. rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional parameter size. - save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. + attention_dtype: cast attention operations to this dtype. """ @@ -58,22 +68,43 @@ def __init__( if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") - if hidden_size % num_heads != 0: - raise ValueError("hidden size should be divisible by num_heads.") + if dim_head: + inner_dim = num_heads * dim_head + self.dim_head = dim_head + else: + if hidden_size % num_heads != 0: + raise ValueError("hidden size should be divisible by num_heads.") + inner_dim = hidden_size + self.dim_head = hidden_size // num_heads + + if causal and sequence_length is None: + raise ValueError("sequence_length is necessary for causal attention.") self.num_heads = num_heads - self.out_proj = nn.Linear(hidden_size, hidden_size) - self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) + self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size + self.out_proj = nn.Linear(inner_dim, self.hidden_input_size) + self.qkv = nn.Linear(self.hidden_input_size, inner_dim * 3, bias=qkv_bias) self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) self.out_rearrange = Rearrange("b h l d -> b l (h d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) - self.head_dim = hidden_size // num_heads - self.scale = self.head_dim**-0.5 + self.scale = self.dim_head**-0.5 self.save_attn = save_attn + self.attention_dtype = attention_dtype + self.causal = causal + self.sequence_length = sequence_length + + if causal and sequence_length is not None: + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "causal_mask", + torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), + ) + self.causal_mask: torch.Tensor + self.att_mat = torch.Tensor() self.rel_positional_embedding = ( - get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads) + get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.dim_head, self.num_heads) if rel_pos_embedding is not None else None ) @@ -89,11 +120,17 @@ def forward(self, x: torch.Tensor): """ output = self.input_rearrange(self.qkv(x)) q, k, v = output[0], output[1], output[2] + if self.attention_dtype is not None: + q = q.to(self.attention_dtype) + k = k.to(self.attention_dtype) att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale # apply relative positional embedding if defined att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + if self.causal: + att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf")) + att_mat = att_mat.softmax(dim=-1) if self.save_attn: diff --git a/monai/networks/blocks/spade_norm.py b/monai/networks/blocks/spade_norm.py index b1046f3154..8e082defe0 100644 --- a/monai/networks/blocks/spade_norm.py +++ b/monai/networks/blocks/spade_norm.py @@ -85,7 +85,7 @@ def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor: """ # Part 1. generate parameter-free normalized activations - normalized = self.param_free_norm(x) + normalized = self.param_free_norm(x.contiguous()) # Part 2. produce scaling and bias conditioned on semantic map segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest") diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py new file mode 100644 index 0000000000..020d8d23fd --- /dev/null +++ b/monai/networks/blocks/spatialattention.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn as nn + +from monai.networks.blocks import SABlock +from monai.utils import optional_import + +Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") + + +class SpatialAttentionBlock(nn.Module): + """Perform spatial self-attention on the input tensor. + + The input tensor is reshaped to B x (x_dim * y_dim [ * z_dim]) x C, where C is the number of channels, and then + self-attention is performed on the reshaped tensor. The output tensor is reshaped back to the original shape. + + Args: + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. + num_channels: number of input channels. Must be divisible by num_head_channels. + num_head_channels: number of channels per head. + attention_dtype: cast attention operations to this dtype. + + """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + attention_dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__() + + self.spatial_dims = spatial_dims + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + # check num_head_channels is divisible by num_channels + if num_head_channels is not None and num_channels % num_head_channels != 0: + raise ValueError("num_channels must be divisible by num_head_channels") + num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.attn = SABlock( + hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype + ) + + def forward(self, x: torch.Tensor): + residual = x + + if self.spatial_dims == 1: + h = x.shape[2] + rearrange_input = Rearrange("b c h -> b h c") + rearrange_output = Rearrange("b h c -> b c h", h=h) + if self.spatial_dims == 2: + h, w = x.shape[2], x.shape[3] + rearrange_input = Rearrange("b c h w -> b (h w) c") + rearrange_output = Rearrange("b (h w) c -> b c h w", h=h, w=w) + if self.spatial_dims == 3: + h, w, d = x.shape[2], x.shape[3], x.shape[4] + rearrange_input = Rearrange("b c h w d -> b (h w d) c") + rearrange_output = Rearrange("b (h w d) c -> b c h w d", h=h, w=w, d=d) + + x = self.norm(x) + x = rearrange_input(x) # B x (x_dim * y_dim [ * z_dim]) x C + + x = self.attn(x) + x = rearrange_output(x) # B x x C x x_dim * y_dim * [z_dim] + x = x + residual + return x diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index ddf959dad2..2458902cba 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -11,10 +11,10 @@ from __future__ import annotations +import torch import torch.nn as nn -from monai.networks.blocks.mlp import MLPBlock -from monai.networks.blocks.selfattention import SABlock +from monai.networks.blocks import CrossAttentionBlock, MLPBlock, SABlock class TransformerBlock(nn.Module): @@ -31,6 +31,9 @@ def __init__( dropout_rate: float = 0.0, qkv_bias: bool = False, save_attn: bool = False, + causal: bool = False, + sequence_length: int | None = None, + with_cross_attention: bool = False, ) -> None: """ Args: @@ -53,10 +56,27 @@ def __init__( self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) self.norm1 = nn.LayerNorm(hidden_size) - self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn) + self.attn = SABlock( + hidden_size, + num_heads, + dropout_rate, + qkv_bias=qkv_bias, + save_attn=save_attn, + causal=causal, + sequence_length=sequence_length, + ) self.norm2 = nn.LayerNorm(hidden_size) + self.with_cross_attention = with_cross_attention - def forward(self, x): + if self.with_cross_attention: + self.norm_cross_attn = nn.LayerNorm(hidden_size) + self.cross_attn = CrossAttentionBlock( + hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False + ) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: x = x + self.attn(self.norm1(x)) + if self.with_cross_attention: + x = x + self.cross_attn(self.norm_cross_attn(x), context=context) x = x + self.mlp(self.norm2(x)) return x diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py index 5689cf1071..fdf31d9701 100644 --- a/monai/networks/nets/attentionunet.py +++ b/monai/networks/nets/attentionunet.py @@ -29,7 +29,7 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - kernel_size: int = 3, + kernel_size: Sequence[int] | int = 3, strides: int = 1, dropout=0.0, ): @@ -219,7 +219,13 @@ def __init__( self.kernel_size = kernel_size self.dropout = dropout - head = ConvBlock(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=channels[0], dropout=dropout) + head = ConvBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + dropout=dropout, + kernel_size=self.kernel_size, + ) reduce_channels = Convolution( spatial_dims=spatial_dims, in_channels=channels[0], @@ -245,6 +251,7 @@ def _create_block(channels: Sequence[int], strides: Sequence[int]) -> nn.Module: out_channels=channels[1], strides=strides[0], dropout=self.dropout, + kernel_size=self.kernel_size, ), subblock, ), @@ -271,6 +278,7 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int) - out_channels=out_channels, strides=strides, dropout=self.dropout, + kernel_size=self.kernel_size, ), up_kernel_size=self.up_kernel_size, strides=strides, diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 372e704d53..17bb90d6f6 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -18,8 +18,7 @@ import torch.nn as nn import torch.nn.functional as F -from monai.networks.blocks import Convolution, Upsample -from monai.networks.blocks.selfattention import SABlock +from monai.networks.blocks import Convolution, SpatialAttentionBlock, Upsample from monai.utils import ensure_tuple_rep, optional_import Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -144,61 +143,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x + h -class AttentionBlock(nn.Module): - """Perform spatial self-attention on the input tensor. - - The input tensor is reshaped to B x (x_dim * y_dim [ * z_dim]) x C, where C is the number of channels. - - Args: - spatial_dims: number of spatial dimensions, could be 1, 2, or 3. - num_channels: number of input channels. Must be divisible by num_head_channels. - num_head_channels: number of channels per head. - """ - - def __init__( - self, - spatial_dims: int, - num_channels: int, - num_head_channels: int | None = None, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - ) -> None: - super().__init__() - - self.spatial_dims = spatial_dims - self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) - # check num_head_channels is divisible by num_channels - if num_head_channels is not None and num_channels % num_head_channels != 0: - raise ValueError("num_channels must be divisible by num_head_channels") - num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 - - self.attn = SABlock(hidden_size=num_channels, num_heads=num_heads, qkv_bias=True) - - def forward(self, x: torch.Tensor): - residual = x - - if self.spatial_dims == 1: - h = x.shape[2] - rearrange_input = Rearrange("b c h -> b h c") - rearrange_output = Rearrange("b h c -> b c h", h=h) - if self.spatial_dims == 2: - h, w = x.shape[2], x.shape[3] - rearrange_input = Rearrange("b c h w -> b (h w) c") - rearrange_output = Rearrange("b (h w) c -> b c h w", h=h, w=w) - if self.spatial_dims == 3: - h, w, d = x.shape[2], x.shape[3], x.shape[4] - rearrange_input = Rearrange("b c h w d -> b (h w d) c") - rearrange_output = Rearrange("b (h w d) c -> b c h w d", h=h, w=w, d=d) - - x = self.norm(x) - x = rearrange_input(x) # B x (x_dim * y_dim [ * z_dim]) x C - - x = self.attn(x) - x = rearrange_output(x) # B x x C x x_dim * y_dim * [z_dim] - x = x + residual - return x - - class Encoder(nn.Module): """ Convolutional cascade that downsamples the image into a spatial latent space. @@ -271,7 +215,7 @@ def __init__( input_channel = output_channel if attention_levels[i]: blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=input_channel, norm_num_groups=norm_num_groups, @@ -294,7 +238,7 @@ def __init__( ) blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=channels[-1], norm_num_groups=norm_num_groups, @@ -401,7 +345,7 @@ def __init__( ) ) blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, @@ -440,7 +384,7 @@ def __init__( if reversed_attention_levels[i]: blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=block_in_ch, norm_num_groups=norm_num_groups, diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py index d98755f401..7450c87314 100644 --- a/monai/networks/nets/controlnet.py +++ b/monai/networks/nets/controlnet.py @@ -141,7 +141,6 @@ class ControlNet(nn.Module): num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. conditioning_embedding_in_channels: number of input channels for the conditioning embedding. conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. """ @@ -162,7 +161,6 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, conditioning_embedding_in_channels: int = 1, conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256), ) -> None: @@ -209,11 +207,6 @@ def __init__( f"`num_channels`, but got num_res_blocks={num_res_blocks} and channels={channels}." ) - if use_flash_attention is True and not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." - ) - self.in_channels = in_channels self.block_out_channels = channels self.num_res_blocks = num_res_blocks @@ -289,7 +282,6 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -334,7 +326,6 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, ) controlnet_block = Convolution( diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 0441cc9cfe..38d7f816a9 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -35,17 +35,13 @@ from collections.abc import Sequence import torch -import torch.nn.functional as F from torch import nn -from monai.networks.blocks import Convolution, MLPBlock +from monai.networks.blocks import Convolution, CrossAttentionBlock, MLPBlock, SABlock, SpatialAttentionBlock, Upsample from monai.networks.layers.factories import Pool from monai.utils import ensure_tuple_rep, optional_import -# To install xformers, use pip install xformers==0.0.16rc401 - -xops, has_xformers = optional_import("xformers.ops") - +Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") __all__ = ["DiffusionModelUNet"] @@ -59,122 +55,9 @@ def zero_module(module: nn.Module) -> nn.Module: return module -class _CrossAttention(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - A cross attention layer. - - Args: - query_dim: number of channels in the query. - cross_attention_dim: number of channels in the context. - num_attention_heads: number of heads to use for multi-head attention. - num_head_channels: number of channels in each head. - dropout: dropout probability to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - query_dim: int, - cross_attention_dim: int | None = None, - num_attention_heads: int = 8, - num_head_channels: int = 64, - dropout: float = 0.0, - upcast_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.use_flash_attention = use_flash_attention - inner_dim = num_head_channels * num_attention_heads - cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - - self.scale = 1 / math.sqrt(num_head_channels) - self.num_heads = num_attention_heads - - self.upcast_attention = upcast_attention - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) - - def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: - """ - Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. - """ - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) - x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) - return x - - def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: - """Combine the output of the attention heads back into the hidden state dimension.""" - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) - return x - - def _memory_efficient_attention_xformers( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - x: torch.Tensor = xops.memory_efficient_attention(query, key, value, attn_bias=None) - return x - - def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - dtype = query.dtype - if self.upcast_attention: - query = query.float() - key = key.float() - - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - attention_probs = attention_probs.to(dtype=dtype) - - x = torch.bmm(attention_probs, value) - return x - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - query = self.to_q(x) - context = context if context is not None else x - key = self.to_k(context) - value = self.to_v(context) - - # Multi-Head Attention - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - if self.use_flash_attention: - x = self._memory_efficient_attention_xformers(query, key, value) - else: - x = self._attention(query, key, value) - - x = self.reshape_batch_dim_to_heads(x) - x = x.to(query.dtype) - output: torch.Tensor = self.to_out(x) - return output - - -class _BasicTransformerBlock(nn.Module): +class DiffusionUNetTransformerBlock(nn.Module): """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - A basic Transformer block. + A Transformer block that allows for the input dimension to differ from the hidden dimension. Args: num_channels: number of channels in the input and output. @@ -183,7 +66,7 @@ class _BasicTransformerBlock(nn.Module): dropout: dropout probability to use. cross_attention_dim: size of the context vector for cross attention. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ def __init__( @@ -194,27 +77,26 @@ def __init__( dropout: float = 0.0, cross_attention_dim: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, ) -> None: super().__init__() - self.attn1 = _CrossAttention( - query_dim=num_channels, - num_attention_heads=num_attention_heads, - num_head_channels=num_head_channels, - dropout=dropout, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) # is a self-attention + self.attn1 = SABlock( + hidden_size=num_attention_heads * num_head_channels, + hidden_input_size=num_channels, + num_heads=num_attention_heads, + dim_head=num_head_channels, + dropout_rate=dropout, + attention_dtype=torch.float if upcast_attention else None, + ) self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) - self.attn2 = _CrossAttention( - query_dim=num_channels, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - num_head_channels=num_head_channels, - dropout=dropout, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) # is a self-attention if context is None + self.attn2 = CrossAttentionBlock( + hidden_size=num_attention_heads * num_head_channels, + num_heads=num_attention_heads, + hidden_input_size=num_channels, + context_input_size=cross_attention_dim, + dim_head=num_head_channels, + dropout_rate=dropout, + attention_dtype=torch.float if upcast_attention else None, + ) self.norm1 = nn.LayerNorm(num_channels) self.norm2 = nn.LayerNorm(num_channels) self.norm3 = nn.LayerNorm(num_channels) @@ -231,7 +113,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch return x -class _SpatialTransformer(nn.Module): +class SpatialTransformer(nn.Module): """ NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make use of this block as support is not guaranteed. For more information see: @@ -251,7 +133,6 @@ class _SpatialTransformer(nn.Module): norm_eps: epsilon for the normalization. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -266,7 +147,6 @@ def __init__( norm_eps: float = 1e-6, cross_attention_dim: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -287,14 +167,13 @@ def __init__( self.transformer_blocks = nn.ModuleList( [ - _BasicTransformerBlock( + DiffusionUNetTransformerBlock( num_channels=inner_dim, num_attention_heads=num_attention_heads, num_head_channels=num_head_channels, dropout=dropout, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, ) for _ in range(num_layers) ] @@ -343,126 +222,6 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch return x + residual -class _AttentionBlock(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to - compute attention. - - Args: - spatial_dims: number of spatial dimensions. - num_channels: number of input channels. - num_head_channels: number of channels in each attention head. - norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of - channels is divisible by this number. - norm_eps: epsilon value to use for the normalisation. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - spatial_dims: int, - num_channels: int, - num_head_channels: int | None = None, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.use_flash_attention = use_flash_attention - self.spatial_dims = spatial_dims - self.num_channels = num_channels - - self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 - self.scale = 1 / math.sqrt(num_channels / self.num_heads) - - self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) - - self.to_q = nn.Linear(num_channels, num_channels) - self.to_k = nn.Linear(num_channels, num_channels) - self.to_v = nn.Linear(num_channels, num_channels) - - self.proj_attn = nn.Linear(num_channels, num_channels) - - def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) - x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) - return x - - def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) - return x - - def _memory_efficient_attention_xformers( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - x: torch.Tensor = xops.memory_efficient_attention(query, key, value, attn_bias=None) - return x - - def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - x = torch.bmm(attention_probs, value) - return x - - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - - batch = channel = height = width = depth = -1 - if self.spatial_dims == 2: - batch, channel, height, width = x.shape - if self.spatial_dims == 3: - batch, channel, height, width, depth = x.shape - - # norm - x = self.norm(x.contiguous()) - - if self.spatial_dims == 2: - x = x.view(batch, channel, height * width).transpose(1, 2) - if self.spatial_dims == 3: - x = x.view(batch, channel, height * width * depth).transpose(1, 2) - - # proj to q, k, v - query = self.to_q(x) - key = self.to_k(x) - value = self.to_v(x) - - # Multi-Head Attention - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - if self.use_flash_attention: - x = self._memory_efficient_attention_xformers(query, key, value) - else: - x = self._attention(query, key, value) - - x = self.reshape_batch_dim_to_heads(x) - x = x.to(query.dtype) - - if self.spatial_dims == 2: - x = x.transpose(-1, -2).reshape(batch, channel, height, width) - if self.spatial_dims == 3: - x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) - - return x + residual - - def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor: """ Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic @@ -490,12 +249,8 @@ def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_peri return embedding -class _Downsample(nn.Module): +class DiffusionUnetDownsample(nn.Module): """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - Downsampling layer. Args: @@ -541,68 +296,19 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Ten return output -class _Upsample(nn.Module): +class WrappedUpsample(Upsample): """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - Upsampling layer with an optional convolution. - - Args: - spatial_dims: number of spatial dimensions. - num_channels: number of input channels. - use_conv: if True uses Convolution instead of Pool average to perform downsampling. - out_channels: number of output channels. - padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each - dimension. + Wraps MONAI upsample block to allow for calling with timestep embeddings. """ - def __init__( - self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 - ) -> None: - super().__init__() - self.num_channels = num_channels - self.out_channels = out_channels or num_channels - self.use_conv = use_conv - if use_conv: - self.conv = Convolution( - spatial_dims=spatial_dims, - in_channels=self.num_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=padding, - conv_only=True, - ) - def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: del emb - if x.shape[1] != self.num_channels: - raise ValueError("Input channels should be equal to num_channels") - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # https://github.com/pytorch/pytorch/issues/86679 - dtype = x.dtype - if dtype == torch.bfloat16: - x = x.to(torch.float32) - - x = F.interpolate(x, scale_factor=2.0, mode="nearest") + upsampled: torch.Tensor = super().forward(x) + return upsampled - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: - x = x.to(dtype) - if self.use_conv: - x = self.conv(x) - return x - - -class _ResnetBlock(nn.Module): +class DiffusionUNetResnetBlock(nn.Module): """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 Residual block with timestep conditioning. Args: @@ -649,9 +355,17 @@ def __init__( self.upsample = self.downsample = None if self.up: - self.upsample = _Upsample(spatial_dims, in_channels, use_conv=False) + self.upsample = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=in_channels, + out_channels=in_channels, + interp_mode="nearest", + scale_factor=2.0, + align_corners=None, + ) elif down: - self.downsample = _Downsample(spatial_dims, in_channels, use_conv=False) + self.downsample = DiffusionUnetDownsample(spatial_dims, in_channels, use_conv=False) self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) @@ -749,7 +463,7 @@ def __init__( for i in range(num_res_blocks): in_channels = in_channels if i == 0 else out_channels resnets.append( - _ResnetBlock( + DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, @@ -764,7 +478,7 @@ def __init__( if add_downsample: self.downsampler: nn.Module | None if resblock_updown: - self.downsampler = _ResnetBlock( + self.downsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -774,7 +488,7 @@ def __init__( down=True, ) else: - self.downsampler = _Downsample( + self.downsampler = DiffusionUnetDownsample( spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, @@ -817,7 +531,6 @@ class AttnDownBlock(nn.Module): resblock_updown: if True use residual blocks for downsampling. downsample_padding: padding used in the downsampling block. num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -833,7 +546,6 @@ def __init__( resblock_updown: bool = False, downsample_padding: int = 1, num_head_channels: int = 1, - use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -844,7 +556,7 @@ def __init__( for i in range(num_res_blocks): in_channels = in_channels if i == 0 else out_channels resnets.append( - _ResnetBlock( + DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, @@ -854,13 +566,12 @@ def __init__( ) ) attentions.append( - _AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=out_channels, num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, - use_flash_attention=use_flash_attention, ) ) @@ -870,7 +581,7 @@ def __init__( self.downsampler: nn.Module | None if add_downsample: if resblock_updown: - self.downsampler = _ResnetBlock( + self.downsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -880,7 +591,7 @@ def __init__( down=True, ) else: - self.downsampler = _Downsample( + self.downsampler = DiffusionUnetDownsample( spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, @@ -927,7 +638,6 @@ class CrossAttnDownBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers """ @@ -947,7 +657,6 @@ def __init__( transformer_num_layers: int = 1, cross_attention_dim: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, dropout_cattn: float = 0.0, ) -> None: super().__init__() @@ -959,7 +668,7 @@ def __init__( for i in range(num_res_blocks): in_channels = in_channels if i == 0 else out_channels resnets.append( - _ResnetBlock( + DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, @@ -970,7 +679,7 @@ def __init__( ) attentions.append( - _SpatialTransformer( + SpatialTransformer( spatial_dims=spatial_dims, in_channels=out_channels, num_attention_heads=out_channels // num_head_channels, @@ -980,7 +689,6 @@ def __init__( norm_eps=norm_eps, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout=dropout_cattn, ) ) @@ -991,7 +699,7 @@ def __init__( self.downsampler: nn.Module | None if add_downsample: if resblock_updown: - self.downsampler = _ResnetBlock( + self.downsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -1001,7 +709,7 @@ def __init__( down=True, ) else: - self.downsampler = _Downsample( + self.downsampler = DiffusionUnetDownsample( spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, @@ -1039,7 +747,6 @@ class AttnMidBlock(nn.Module): norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -1050,11 +757,10 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, num_head_channels: int = 1, - use_flash_attention: bool = False, ) -> None: super().__init__() - self.resnet_1 = _ResnetBlock( + self.resnet_1 = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, @@ -1062,16 +768,15 @@ def __init__( norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) - self.attention = _AttentionBlock( + self.attention = SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=in_channels, num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, - use_flash_attention=use_flash_attention, ) - self.resnet_2 = _ResnetBlock( + self.resnet_2 = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, @@ -1105,7 +810,6 @@ class CrossAttnMidBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -1119,12 +823,11 @@ def __init__( transformer_num_layers: int = 1, cross_attention_dim: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, dropout_cattn: float = 0.0, ) -> None: super().__init__() - self.resnet_1 = _ResnetBlock( + self.resnet_1 = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, @@ -1132,7 +835,7 @@ def __init__( norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) - self.attention = _SpatialTransformer( + self.attention = SpatialTransformer( spatial_dims=spatial_dims, in_channels=in_channels, num_attention_heads=in_channels // num_head_channels, @@ -1142,10 +845,9 @@ def __init__( norm_eps=norm_eps, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout=dropout_cattn, ) - self.resnet_2 = _ResnetBlock( + self.resnet_2 = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, @@ -1203,7 +905,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - _ResnetBlock( + DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, @@ -1218,7 +920,7 @@ def __init__( self.upsampler: nn.Module | None if add_upsample: if resblock_updown: - self.upsampler = _ResnetBlock( + self.upsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -1228,9 +930,26 @@ def __init__( up=True, ) else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, ) + else: self.upsampler = None @@ -1272,7 +991,6 @@ class AttnUpBlock(nn.Module): add_upsample: if True add downsample block. resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -1288,7 +1006,6 @@ def __init__( add_upsample: bool = True, resblock_updown: bool = False, num_head_channels: int = 1, - use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -1301,7 +1018,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - _ResnetBlock( + DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, @@ -1311,13 +1028,12 @@ def __init__( ) ) attentions.append( - _AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=out_channels, num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, - use_flash_attention=use_flash_attention, ) ) @@ -1327,7 +1043,7 @@ def __init__( self.upsampler: nn.Module | None if add_upsample: if resblock_updown: - self.upsampler = _ResnetBlock( + self.upsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -1337,8 +1053,25 @@ def __init__( up=True, ) else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, ) else: self.upsampler = None @@ -1385,7 +1118,6 @@ class CrossAttnUpBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers """ @@ -1405,7 +1137,6 @@ def __init__( transformer_num_layers: int = 1, cross_attention_dim: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, dropout_cattn: float = 0.0, ) -> None: super().__init__() @@ -1419,7 +1150,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - _ResnetBlock( + DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, @@ -1429,7 +1160,7 @@ def __init__( ) ) attentions.append( - _SpatialTransformer( + SpatialTransformer( spatial_dims=spatial_dims, in_channels=out_channels, num_attention_heads=out_channels // num_head_channels, @@ -1439,7 +1170,6 @@ def __init__( num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout=dropout_cattn, ) ) @@ -1450,7 +1180,7 @@ def __init__( self.upsampler: nn.Module | None if add_upsample: if resblock_updown: - self.upsampler = _ResnetBlock( + self.upsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -1460,8 +1190,25 @@ def __init__( up=True, ) else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, ) else: self.upsampler = None @@ -1504,7 +1251,6 @@ def get_down_block( transformer_num_layers: int, cross_attention_dim: int | None, upcast_attention: bool = False, - use_flash_attention: bool = False, dropout_cattn: float = 0.0, ) -> nn.Module: if with_attn: @@ -1519,7 +1265,6 @@ def get_down_block( add_downsample=add_downsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnDownBlock( @@ -1536,7 +1281,6 @@ def get_down_block( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) else: @@ -1564,7 +1308,6 @@ def get_mid_block( transformer_num_layers: int, cross_attention_dim: int | None, upcast_attention: bool = False, - use_flash_attention: bool = False, dropout_cattn: float = 0.0, ) -> nn.Module: if with_conditioning: @@ -1578,7 +1321,6 @@ def get_mid_block( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) else: @@ -1589,7 +1331,6 @@ def get_mid_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, ) @@ -1610,7 +1351,6 @@ def get_up_block( transformer_num_layers: int, cross_attention_dim: int | None, upcast_attention: bool = False, - use_flash_attention: bool = False, dropout_cattn: float = 0.0, ) -> nn.Module: if with_attn: @@ -1626,7 +1366,6 @@ def get_up_block( add_upsample=add_upsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnUpBlock( @@ -1644,7 +1383,6 @@ def get_up_block( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) else: @@ -1685,7 +1423,6 @@ class DiffusionModelUNet(nn.Module): num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers """ @@ -1706,7 +1443,6 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, dropout_cattn: float = 0.0, ) -> None: super().__init__() @@ -1747,14 +1483,6 @@ def __init__( "`num_channels`." ) - if use_flash_attention and not has_xformers: - raise ValueError("use_flash_attention is True but xformers is not installed.") - - if use_flash_attention is True and not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." - ) - self.in_channels = in_channels self.block_out_channels = channels self.out_channels = out_channels @@ -1809,7 +1537,6 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) @@ -1827,7 +1554,6 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) @@ -1862,7 +1588,6 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) @@ -1944,7 +1669,7 @@ def forward( down_block_res_samples = new_down_block_res_samples # 5. mid - h = self.middle_block(hidden_states=h, temb=emb, context=context) + h = self.middle_block(hidden_states=h.contiguous(), temb=emb, context=context) # Additional residual conections for Controlnets if mid_block_additional_residual is not None: @@ -1961,6 +1686,63 @@ def forward( return output + def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: + """ + Load a state dict from a DiffusionModelUNet trained with + [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels). + + Args: + old_state_dict: state dict from the old DecoderOnlyTransformer model. + """ + + new_state_dict = self.state_dict() + # if all keys match, just load the state dict + if all(k in new_state_dict for k in old_state_dict): + print("All keys match, loading state dict.") + self.load_state_dict(old_state_dict) + return + + if verbose: + # print all new_state_dict keys that are not in old_state_dict + for k in new_state_dict: + if k not in old_state_dict: + print(f"key {k} not found in old state dict") + # and vice versa + print("----------------------------------------------") + for k in old_state_dict: + if k not in new_state_dict: + print(f"key {k} not found in new state dict") + + # copy over all matching keys + for k in new_state_dict: + if k in old_state_dict: + new_state_dict[k] = old_state_dict[k] + + # fix the attention blocks + attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] + for block in attention_blocks: + new_state_dict[f"{block}.attn1.qkv.weight"] = torch.concat( + [ + old_state_dict[f"{block}.attn1.to_q.weight"], + old_state_dict[f"{block}.attn1.to_k.weight"], + old_state_dict[f"{block}.attn1.to_v.weight"], + ], + dim=0, + ) + + # projection + new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"] + new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"] + + new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"] + new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"] + # fix the upsample conv blocks which were renamed postconv + for k in new_state_dict: + if "postconv" in k: + old_name = k.replace("postconv", "conv") + new_state_dict[k] = old_state_dict[old_name] + self.load_state_dict(new_state_dict) + class DiffusionModelEncoder(nn.Module): """ diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 99975271da..74d15bc6bf 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -46,7 +46,6 @@ "resnet200", ] - resnet_params = { # model_name: (block, layers, shortcut_type, bias_downsample, datasets23) "resnet10": ("basic", [1, 1, 1, 1], "B", False, True), diff --git a/monai/networks/nets/spade_autoencoderkl.py b/monai/networks/nets/spade_autoencoderkl.py index 0949e307b9..294b121c94 100644 --- a/monai/networks/nets/spade_autoencoderkl.py +++ b/monai/networks/nets/spade_autoencoderkl.py @@ -17,9 +17,9 @@ import torch.nn as nn import torch.nn.functional as F -from monai.networks.blocks import Convolution, Upsample +from monai.networks.blocks import Convolution, SpatialAttentionBlock, Upsample from monai.networks.blocks.spade_norm import SPADE -from monai.networks.nets.autoencoderkl import AttentionBlock, Encoder +from monai.networks.nets.autoencoderkl import Encoder from monai.utils import ensure_tuple_rep __all__ = ["SPADEAutoencoderKL"] @@ -195,7 +195,7 @@ def __init__( ) ) blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, @@ -238,7 +238,7 @@ def __init__( if reversed_attention_levels[i]: blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=block_in_ch, norm_num_groups=norm_num_groups, diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py index bffc9c5465..e019d21c11 100644 --- a/monai/networks/nets/spade_diffusion_model_unet.py +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -36,24 +36,19 @@ import torch from torch import nn -from monai.networks.blocks import Convolution +from monai.networks.blocks import Convolution, SpatialAttentionBlock from monai.networks.blocks.spade_norm import SPADE from monai.networks.nets.diffusion_model_unet import ( - _AttentionBlock, - _Downsample, - _ResnetBlock, - _SpatialTransformer, - _Upsample, + DiffusionUnetDownsample, + DiffusionUNetResnetBlock, + SpatialTransformer, + WrappedUpsample, get_down_block, get_mid_block, get_timestep_embedding, zero_module, ) -from monai.utils import ensure_tuple_rep, optional_import - -# To install xformers, use pip install xformers==0.0.16rc401 -xops, has_xformers = optional_import("xformers.ops") - +from monai.utils import ensure_tuple_rep __all__ = ["SPADEDiffusionModelUNet"] @@ -120,9 +115,17 @@ def __init__( self.upsample = self.downsample = None if self.up: - self.upsample = _Upsample(spatial_dims, in_channels, use_conv=False) + self.upsample = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=in_channels, + out_channels=in_channels, + interp_mode="nearest", + scale_factor=2.0, + align_corners=None, + ) elif down: - self.downsample = _Downsample(spatial_dims, in_channels, use_conv=False) + self.downsample = DiffusionUnetDownsample(spatial_dims, in_channels, use_conv=False) self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) @@ -252,7 +255,7 @@ def __init__( self.upsampler: nn.Module | None if add_upsample: if resblock_updown: - self.upsampler = _ResnetBlock( + self.upsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -262,8 +265,24 @@ def __init__( up=True, ) else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, ) else: self.upsampler = None @@ -308,7 +327,6 @@ class SPADEAttnUpBlock(nn.Module): add_upsample: if True add downsample block. resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. spade_intermediate_channels: number of intermediate channels for SPADE block layer """ @@ -326,7 +344,6 @@ def __init__( add_upsample: bool = True, resblock_updown: bool = False, num_head_channels: int = 1, - use_flash_attention: bool = False, spade_intermediate_channels: int = 128, ) -> None: super().__init__() @@ -351,13 +368,12 @@ def __init__( ) ) attentions.append( - _AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=out_channels, num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, - use_flash_attention=use_flash_attention, ) ) @@ -367,7 +383,7 @@ def __init__( self.upsampler: nn.Module | None if add_upsample: if resblock_updown: - self.upsampler = _ResnetBlock( + self.upsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -377,8 +393,24 @@ def __init__( up=True, ) else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, ) else: self.upsampler = None @@ -427,7 +459,6 @@ class SPADECrossAttnUpBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. spade_intermediate_channels: number of intermediate channels for SPADE block layer. """ @@ -448,7 +479,6 @@ def __init__( transformer_num_layers: int = 1, cross_attention_dim: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, spade_intermediate_channels: int = 128, ) -> None: super().__init__() @@ -473,7 +503,7 @@ def __init__( ) ) attentions.append( - _SpatialTransformer( + SpatialTransformer( spatial_dims=spatial_dims, in_channels=out_channels, num_attention_heads=out_channels // num_head_channels, @@ -483,7 +513,6 @@ def __init__( num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, ) ) @@ -493,7 +522,7 @@ def __init__( self.upsampler: nn.Module | None if add_upsample: if resblock_updown: - self.upsampler = _ResnetBlock( + self.upsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -503,8 +532,24 @@ def __init__( up=True, ) else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, ) else: self.upsampler = None @@ -549,7 +594,6 @@ def get_spade_up_block( label_nc: int, cross_attention_dim: int | None, upcast_attention: bool = False, - use_flash_attention: bool = False, spade_intermediate_channels: int = 128, ) -> nn.Module: if with_attn: @@ -566,7 +610,6 @@ def get_spade_up_block( add_upsample=add_upsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, spade_intermediate_channels=spade_intermediate_channels, ) elif with_cross_attn: @@ -586,7 +629,6 @@ def get_spade_up_block( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, spade_intermediate_channels=spade_intermediate_channels, ) else: @@ -630,7 +672,6 @@ class SPADEDiffusionModelUNet(nn.Module): num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. spade_intermediate_channels: number of intermediate channels for SPADE block layer """ @@ -652,7 +693,6 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, spade_intermediate_channels: int = 128, ) -> None: super().__init__() @@ -691,14 +731,6 @@ def __init__( "`num_channels`." ) - if use_flash_attention and not has_xformers: - raise ValueError("use_flash_attention is True but xformers is not installed.") - - if use_flash_attention is True and not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." - ) - self.in_channels = in_channels self.block_out_channels = channels self.out_channels = out_channels @@ -754,7 +786,6 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -771,7 +802,6 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, ) # up @@ -805,7 +835,6 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, label_nc=label_nc, spade_intermediate_channels=spade_intermediate_channels, ) @@ -890,7 +919,7 @@ def forward( down_block_res_samples = new_down_block_res_samples # 5. mid - h = self.middle_block(hidden_states=h, temb=emb, context=context) + h = self.middle_block(hidden_states=h.contiguous(), temb=emb, context=context) # Additional residual conections for Controlnets if mid_block_additional_residual is not None: diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py index b742c12205..215e8d11a9 100644 --- a/monai/networks/nets/transformer.py +++ b/monai/networks/nets/transformer.py @@ -11,221 +11,14 @@ from __future__ import annotations -import math - import torch import torch.nn as nn -import torch.nn.functional as F -from monai.networks.blocks.mlp import MLPBlock -from monai.utils import optional_import +from monai.networks.blocks import TransformerBlock -xops, has_xformers = optional_import("xformers.ops") __all__ = ["DecoderOnlyTransformer"] -class _SABlock(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - A self-attention block, based on: "Dosovitskiy et al., - An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " - - Args: - hidden_size: dimension of hidden layer. - num_heads: number of attention heads. - dropout_rate: dropout ratio. Defaults to no dropout. - qkv_bias: bias term for the qkv linear layer. - causal: whether to use causal attention. - sequence_length: if causal is True, it is necessary to specify the sequence length. - with_cross_attention: Whether to use cross attention for conditioning. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - hidden_size: int, - num_heads: int, - dropout_rate: float = 0.0, - qkv_bias: bool = False, - causal: bool = False, - sequence_length: int | None = None, - with_cross_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = hidden_size // num_heads - self.scale = 1.0 / math.sqrt(self.head_dim) - self.causal = causal - self.sequence_length = sequence_length - self.with_cross_attention = with_cross_attention - self.use_flash_attention = use_flash_attention - - if not (0 <= dropout_rate <= 1): - raise ValueError("dropout_rate should be between 0 and 1.") - self.dropout_rate = dropout_rate - - if hidden_size % num_heads != 0: - raise ValueError("hidden size should be divisible by num_heads.") - - if causal and sequence_length is None: - raise ValueError("sequence_length is necessary for causal attention.") - - if use_flash_attention and not has_xformers: - raise ValueError("use_flash_attention is True but xformers is not installed.") - - # key, query, value projections - self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - self.to_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - - # regularization - self.drop_weights = nn.Dropout(dropout_rate) - self.drop_output = nn.Dropout(dropout_rate) - - # output projection - self.out_proj = nn.Linear(hidden_size, hidden_size) - - if causal and sequence_length is not None: - # causal mask to ensure that attention is only applied to the left in the input sequence - self.register_buffer( - "causal_mask", - torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), - ) - self.causal_mask: torch.Tensor - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) - - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - query = self.to_q(x) - - kv = context if context is not None else x - _, kv_t, _ = kv.size() - key = self.to_k(kv) - value = self.to_v(kv) - - query = query.view(b, t, self.num_heads, c // self.num_heads) # (b, t, nh, hs) - key = key.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) - value = value.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) - y: torch.Tensor - if self.use_flash_attention: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - y = xops.memory_efficient_attention( - query=query, - key=key, - value=value, - scale=self.scale, - p=self.dropout_rate, - attn_bias=xops.LowerTriangularMask() if self.causal else None, - ) - - else: - query = query.transpose(1, 2) # (b, nh, t, hs) - key = key.transpose(1, 2) # (b, nh, kv_t, hs) - value = value.transpose(1, 2) # (b, nh, kv_t, hs) - - # manual implementation of attention - query = query * self.scale - attention_scores = query @ key.transpose(-2, -1) - - if self.causal: - attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) - - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = self.drop_weights(attention_probs) - y = attention_probs @ value # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs) - - y = y.transpose(1, 2) # (b, nh, t, hs) -> (b, t, nh, hs) - - y = y.contiguous().view(b, t, c) # re-assemble all head outputs side by side - - y = self.out_proj(y) - y = self.drop_output(y) - return y - - -class _TransformerBlock(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - A transformer block, based on: "Dosovitskiy et al., - An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " - - Args: - hidden_size: dimension of hidden layer. - mlp_dim: dimension of feedforward layer. - num_heads: number of attention heads. - dropout_rate: faction of the input units to drop. - qkv_bias: apply bias term for the qkv linear layer - causal: whether to use causal attention. - sequence_length: if causal is True, it is necessary to specify the sequence length. - with_cross_attention: Whether to use cross attention for conditioning. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - hidden_size: int, - mlp_dim: int, - num_heads: int, - dropout_rate: float = 0.0, - qkv_bias: bool = False, - causal: bool = False, - sequence_length: int | None = None, - with_cross_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - self.with_cross_attention = with_cross_attention - super().__init__() - - if not (0 <= dropout_rate <= 1): - raise ValueError("dropout_rate should be between 0 and 1.") - - if hidden_size % num_heads != 0: - raise ValueError("hidden_size should be divisible by num_heads.") - - self.norm1 = nn.LayerNorm(hidden_size) - self.attn = _SABlock( - hidden_size=hidden_size, - num_heads=num_heads, - dropout_rate=dropout_rate, - qkv_bias=qkv_bias, - causal=causal, - sequence_length=sequence_length, - use_flash_attention=use_flash_attention, - ) - - if self.with_cross_attention: - self.norm2 = nn.LayerNorm(hidden_size) - self.cross_attn = _SABlock( - hidden_size=hidden_size, - num_heads=num_heads, - dropout_rate=dropout_rate, - qkv_bias=qkv_bias, - with_cross_attention=with_cross_attention, - causal=False, - use_flash_attention=use_flash_attention, - ) - self.norm3 = nn.LayerNorm(hidden_size) - self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - x = x + self.attn(self.norm1(x)) - if self.with_cross_attention: - x = x + self.cross_attn(self.norm2(x), context=context) - x = x + self.mlp(self.norm3(x)) - return x - - class AbsolutePositionalEmbedding(nn.Module): """Absolute positional embedding. @@ -258,7 +51,6 @@ class DecoderOnlyTransformer(nn.Module): attn_layers_heads: Number of attention heads. with_cross_attention: Whether to use cross attention for conditioning. embedding_dropout_rate: Dropout rate for the embedding. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -270,7 +62,6 @@ def __init__( attn_layers_heads: int, with_cross_attention: bool = False, embedding_dropout_rate: float = 0.0, - use_flash_attention: bool = False, ) -> None: super().__init__() self.num_tokens = num_tokens @@ -286,7 +77,7 @@ def __init__( self.blocks = nn.ModuleList( [ - _TransformerBlock( + TransformerBlock( hidden_size=attn_layers_dim, mlp_dim=attn_layers_dim * 4, num_heads=attn_layers_heads, @@ -295,7 +86,6 @@ def __init__( causal=True, sequence_length=max_seq_len, with_cross_attention=with_cross_attention, - use_flash_attention=use_flash_attention, ) for _ in range(attn_layers_depth) ] @@ -312,3 +102,56 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch x = block(x, context=context) logits: torch.Tensor = self.to_logits(x) return logits + + def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: + """ + Load a state dict from a DecoderOnlyTransformer trained with + [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels). + + Args: + old_state_dict: state dict from the old DecoderOnlyTransformer model. + """ + + new_state_dict = self.state_dict() + # if all keys match, just load the state dict + if all(k in new_state_dict for k in old_state_dict): + print("All keys match, loading state dict.") + self.load_state_dict(old_state_dict) + return + + if verbose: + # print all new_state_dict keys that are not in old_state_dict + for k in new_state_dict: + if k not in old_state_dict: + print(f"key {k} not found in old state dict") + # and vice versa + print("----------------------------------------------") + for k in old_state_dict: + if k not in new_state_dict: + print(f"key {k} not found in new state dict") + + # copy over all matching keys + for k in new_state_dict: + if k in old_state_dict: + new_state_dict[k] = old_state_dict[k] + + # fix the attention blocks + attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k] + for block in attention_blocks: + new_state_dict[f"{block}.attn.qkv.weight"] = torch.concat( + [ + old_state_dict[f"{block}.attn.to_q.weight"], + old_state_dict[f"{block}.attn.to_k.weight"], + old_state_dict[f"{block}.attn.to_v.weight"], + ], + dim=0, + ) + + # fix the renamed norm blocks first norm2 -> norm_cross_attention , norm3 -> norm2 + for k in old_state_dict: + if "norm2" in k: + new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict[k] + if "norm3" in k: + new_state_dict[k.replace("norm3", "norm2")] = old_state_dict[k] + + self.load_state_dict(new_state_dict) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index ecf237a2ff..6a97434215 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -987,6 +987,7 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int): inputs=input_placeholder, enabled_precisions=convert_precision, device=target_device, + ir="torchscript", **kwargs, ) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 886103a0ab..84dd3ad1f6 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -527,7 +527,7 @@ def doc_images() -> str | None: @staticmethod def algo_hash() -> str | None: - return os.environ.get("MONAI_ALGO_HASH", "4403f94") + return os.environ.get("MONAI_ALGO_HASH", "e4cf5a1") @staticmethod def trace_transform() -> str | None: diff --git a/requirements-dev.txt b/requirements-dev.txt index b207b56b19..ce28d3ebe2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -46,13 +46,13 @@ pynrrd pre-commit pydicom h5py -nni; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine +nni==2.10.1; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine optuna git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded onnx>=1.13.0 onnxruntime; python_version <= '3.10' typeguard<3 # https://github.com/microsoft/nni/issues/5457 -filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523 +filelock<3.12.0 # https://github.com/microsoft/nni/issues/5523 zarr lpips==0.1.4 nvidia-ml-py diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py index 83f6cabc5e..6a577f763f 100644 --- a/tests/test_attentionunet.py +++ b/tests/test_attentionunet.py @@ -14,11 +14,17 @@ import unittest import torch +import torch.nn as nn import monai.networks.nets.attentionunet as att from tests.utils import skip_if_no_cuda, skip_if_quick +def get_net_parameters(net: nn.Module) -> int: + """Returns the total number of parameters in a Module.""" + return sum(param.numel() for param in net.parameters()) + + class TestAttentionUnet(unittest.TestCase): def test_attention_block(self): @@ -50,6 +56,20 @@ def test_attentionunet(self): self.assertEqual(output.shape[0], input.shape[0]) self.assertEqual(output.shape[1], 2) + def test_attentionunet_kernel_size(self): + args_dict = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 2, + "channels": (3, 4, 5), + "up_kernel_size": 5, + "strides": (1, 2), + } + model_a = att.AttentionUnet(**args_dict, kernel_size=5) + model_b = att.AttentionUnet(**args_dict, kernel_size=7) + self.assertEqual(get_net_parameters(model_a), 3534) + self.assertEqual(get_net_parameters(model_b), 5574) + @skip_if_no_cuda def test_attentionunet_gpu(self): for dims in [2, 3]: diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py index 3cc671a1d0..d15cb79084 100644 --- a/tests/test_autoencoderkl.py +++ b/tests/test_autoencoderkl.py @@ -11,20 +11,26 @@ from __future__ import annotations +import os +import tempfile import unittest +from unittest import skipUnless import torch from parameterized import parameterized +from monai.apps import download_url from monai.networks import eval_mode from monai.networks.nets import AutoencoderKL from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, testing_data_config tqdm, has_tqdm = optional_import("tqdm", name="tqdm") -einops, has_einops = optional_import("einops") +_, has_einops = optional_import("einops") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + CASES_NO_ATTENTION = [ [ { @@ -299,6 +305,33 @@ def test_shape_decode_convtranspose_and_checkpointing(self): result = net.decode(torch.randn(latent_shape).to(device)) self.assertEqual(result.shape, expected_input_shape) + @skipUnless(has_einops, "Requires einops") + def test_compatibility_with_monai_generative(self): + # test loading weights from a model saved in MONAI Generative, version 0.2.3 + with skip_if_downloading_fails(): + net = AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(4, 4, 4), + latent_channels=4, + attention_levels=(False, False, True), + num_res_blocks=1, + norm_num_groups=4, + ).to(device) + + tmpdir = tempfile.mkdtemp() + key = "autoencoderkl_monai_generative_weights" + url = testing_data_config("models", key, "url") + hash_type = testing_data_config("models", key, "hash_type") + hash_val = testing_data_config("models", key, "hash_val") + filename = "autoencoderkl_monai_generative_weights.pt" + + weight_path = os.path.join(tmpdir, filename) + download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) + + net.load_old_state_dict(torch.load(weight_path), verbose=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_bundle_ckpt_export.py b/tests/test_bundle_ckpt_export.py index 8f376a06d5..cfcadcfc4c 100644 --- a/tests/test_bundle_ckpt_export.py +++ b/tests/test_bundle_ckpt_export.py @@ -72,9 +72,9 @@ def test_export(self, key_in_ckpt, use_trace): _, metadata, extra_files = load_net_with_metadata( ts_file, more_extra_files=["inference.json", "def_args.json"] ) - self.assertTrue("schema" in metadata) - self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"])) - self.assertTrue("network_def" in json.loads(extra_files["inference.json"])) + self.assertIn("schema", metadata) + self.assertIn("meta_file", json.loads(extra_files["def_args.json"])) + self.assertIn("network_def", json.loads(extra_files["inference.json"])) @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_default_value(self, key_in_ckpt, use_trace): diff --git a/tests/test_bundle_get_data.py b/tests/test_bundle_get_data.py index 605b3945bb..f84713fbe3 100644 --- a/tests/test_bundle_get_data.py +++ b/tests/test_bundle_get_data.py @@ -51,8 +51,8 @@ class TestGetBundleData(unittest.TestCase): def test_get_all_bundles_list(self, params): with skip_if_downloading_fails(): output = get_all_bundles_list(**params) - self.assertTrue(isinstance(output, list)) - self.assertTrue(isinstance(output[0], tuple)) + self.assertIsInstance(output, list) + self.assertIsInstance(output[0], tuple) self.assertTrue(len(output[0]) == 2) @parameterized.expand([TEST_CASE_1, TEST_CASE_5]) @@ -60,16 +60,17 @@ def test_get_all_bundles_list(self, params): def test_get_bundle_versions(self, params): with skip_if_downloading_fails(): output = get_bundle_versions(**params) - self.assertTrue(isinstance(output, dict)) - self.assertTrue("latest_version" in output and "all_versions" in output) - self.assertTrue("0.1.0" in output["all_versions"]) + self.assertIsInstance(output, dict) + self.assertIn("latest_version", output) + self.assertIn("all_versions", output) + self.assertIn("0.1.0", output["all_versions"]) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) @skip_if_quick def test_get_bundle_info(self, params): with skip_if_downloading_fails(): output = get_bundle_info(**params) - self.assertTrue(isinstance(output, dict)) + self.assertIsInstance(output, dict) for key in ["id", "name", "size", "download_count", "browser_download_url"]: self.assertTrue(key in output) @@ -78,7 +79,7 @@ def test_get_bundle_info(self, params): def test_get_bundle_info_monaihosting(self, params): with skip_if_downloading_fails(): output = get_bundle_info(**params) - self.assertTrue(isinstance(output, dict)) + self.assertIsInstance(output, dict) for key in ["name", "browser_download_url"]: self.assertTrue(key in output) diff --git a/tests/test_bundle_trt_export.py b/tests/test_bundle_trt_export.py index 47034852ef..833a0ca1dc 100644 --- a/tests/test_bundle_trt_export.py +++ b/tests/test_bundle_trt_export.py @@ -91,9 +91,9 @@ def test_trt_export(self, convert_precision, input_shape, dynamic_batch): _, metadata, extra_files = load_net_with_metadata( ts_file, more_extra_files=["inference.json", "def_args.json"] ) - self.assertTrue("schema" in metadata) - self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"])) - self.assertTrue("network_def" in json.loads(extra_files["inference.json"])) + self.assertIn("schema", metadata) + self.assertIn("meta_file", json.loads(extra_files["def_args.json"])) + self.assertIn("network_def", json.loads(extra_files["inference.json"])) @parameterized.expand([TEST_CASE_3, TEST_CASE_4]) @unittest.skipUnless( @@ -129,9 +129,9 @@ def test_onnx_trt_export(self, convert_precision, input_shape, dynamic_batch): _, metadata, extra_files = load_net_with_metadata( ts_file, more_extra_files=["inference.json", "def_args.json"] ) - self.assertTrue("schema" in metadata) - self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"])) - self.assertTrue("network_def" in json.loads(extra_files["inference.json"])) + self.assertIn("schema", metadata) + self.assertIn("meta_file", json.loads(extra_files["def_args.json"])) + self.assertIn("network_def", json.loads(extra_files["inference.json"])) if __name__ == "__main__": diff --git a/tests/test_bundle_workflow.py b/tests/test_bundle_workflow.py index 9a276b577f..1727fcdf53 100644 --- a/tests/test_bundle_workflow.py +++ b/tests/test_bundle_workflow.py @@ -138,11 +138,11 @@ def test_train_config(self, config_file): self.assertListEqual(trainer.check_properties(), []) # test read / write the properties dataset = trainer.train_dataset - self.assertTrue(isinstance(dataset, Dataset)) + self.assertIsInstance(dataset, Dataset) inferer = trainer.train_inferer - self.assertTrue(isinstance(inferer, SimpleInferer)) + self.assertIsInstance(inferer, SimpleInferer) # test optional properties get - self.assertTrue(trainer.train_key_metric is None) + self.assertIsNone(trainer.train_key_metric) trainer.train_dataset = deepcopy(dataset) trainer.train_inferer = deepcopy(inferer) # test optional properties set diff --git a/tests/test_clip_intensity_percentilesd.py b/tests/test_clip_intensity_percentilesd.py index fa727b6adb..ed4fc588cb 100644 --- a/tests/test_clip_intensity_percentilesd.py +++ b/tests/test_clip_intensity_percentilesd.py @@ -96,7 +96,7 @@ def test_channel_wise(self, p): for i, c in enumerate(im): lower, upper = percentile(c, (5, 95)) expected = clip(c, lower, upper) - assert_allclose(result[key][i], p(expected), type_test="tensor", rtol=1e-4, atol=0) + assert_allclose(result[key][i], p(expected), type_test="tensor", rtol=1e-3, atol=0) def test_ill_sharpness_factor(self): key = "img" diff --git a/tests/test_component_store.py b/tests/test_component_store.py index 424eceb3d1..7e7c6dd19d 100644 --- a/tests/test_component_store.py +++ b/tests/test_component_store.py @@ -48,17 +48,17 @@ def test_add2(self): self.cs.add("test_obj2", "Test object", test_obj2) self.assertEqual(len(self.cs), 2) - self.assertTrue("test_obj1" in self.cs) - self.assertTrue("test_obj2" in self.cs) + self.assertIn("test_obj1", self.cs) + self.assertIn("test_obj2", self.cs) def test_add_def(self): - self.assertFalse("test_func" in self.cs) + self.assertNotIn("test_func", self.cs) @self.cs.add_def("test_func", "Test function") def test_func(): return 123 - self.assertTrue("test_func" in self.cs) + self.assertIn("test_func", self.cs) self.assertEqual(len(self.cs), 1) self.assertEqual(list(self.cs), [("test_func", test_func)]) diff --git a/tests/test_compute_ho_ver_maps.py b/tests/test_compute_ho_ver_maps.py index bbd5230f04..6e46cf2b1e 100644 --- a/tests/test_compute_ho_ver_maps.py +++ b/tests/test_compute_ho_ver_maps.py @@ -67,8 +67,8 @@ class ComputeHoVerMapsTests(unittest.TestCase): def test_horizontal_certical_maps(self, in_type, arguments, mask, hv_mask): input_image = in_type(mask) result = ComputeHoVerMaps(**arguments)(input_image) - self.assertTrue(isinstance(result, torch.Tensor)) - self.assertTrue(str(result.dtype).split(".")[1] == arguments.get("dtype", "float32")) + self.assertIsInstance(result, torch.Tensor) + self.assertEqual(str(result.dtype).split(".")[1], arguments.get("dtype", "float32")) assert_allclose(result, hv_mask, type_test="tensor") diff --git a/tests/test_compute_ho_ver_maps_d.py b/tests/test_compute_ho_ver_maps_d.py index 7b5ac0d9d7..0734e2e731 100644 --- a/tests/test_compute_ho_ver_maps_d.py +++ b/tests/test_compute_ho_ver_maps_d.py @@ -71,8 +71,8 @@ def test_horizontal_certical_maps(self, in_type, arguments, mask, hv_mask): for k in mask.keys(): input_image[k] = in_type(mask[k]) result = ComputeHoVerMapsd(keys="mask", **arguments)(input_image)[hv_key] - self.assertTrue(isinstance(result, torch.Tensor)) - self.assertTrue(str(result.dtype).split(".")[1] == arguments.get("dtype", "float32")) + self.assertIsInstance(result, torch.Tensor) + self.assertEqual(str(result.dtype).split(".")[1], arguments.get("dtype", "float32")) assert_allclose(result, hv_mask[hv_key], type_test="tensor") diff --git a/tests/test_compute_regression_metrics.py b/tests/test_compute_regression_metrics.py index a8b7f03e47..c407ab6ba6 100644 --- a/tests/test_compute_regression_metrics.py +++ b/tests/test_compute_regression_metrics.py @@ -70,22 +70,24 @@ def test_shape_reduction(self): mt = mt_fn(reduction="mean") mt(in_tensor, in_tensor) out_tensor = mt.aggregate() - self.assertTrue(len(out_tensor.shape) == 1) + self.assertEqual(len(out_tensor.shape), 1) mt = mt_fn(reduction="sum") mt(in_tensor, in_tensor) out_tensor = mt.aggregate() - self.assertTrue(len(out_tensor.shape) == 0) + self.assertEqual(len(out_tensor.shape), 0) mt = mt_fn(reduction="sum") # test reduction arg overriding mt(in_tensor, in_tensor) out_tensor = mt.aggregate(reduction="mean_channel") - self.assertTrue(len(out_tensor.shape) == 1 and out_tensor.shape[0] == batch) + self.assertEqual(len(out_tensor.shape), 1) + self.assertEqual(out_tensor.shape[0], batch) mt = mt_fn(reduction="sum_channel") mt(in_tensor, in_tensor) out_tensor = mt.aggregate() - self.assertTrue(len(out_tensor.shape) == 1 and out_tensor.shape[0] == batch) + self.assertEqual(len(out_tensor.shape), 1) + self.assertEqual(out_tensor.shape[0], batch) def test_compare_numpy(self): set_determinism(seed=123) diff --git a/tests/test_concat_itemsd.py b/tests/test_concat_itemsd.py index 64c5d6e255..564ddf5c1f 100644 --- a/tests/test_concat_itemsd.py +++ b/tests/test_concat_itemsd.py @@ -30,7 +30,7 @@ def test_tensor_values(self): "img2": torch.tensor([[0, 1], [1, 2]], device=device), } result = ConcatItemsd(keys=["img1", "img2"], name="cat_img")(input_data) - self.assertTrue("cat_img" in result) + self.assertIn("cat_img", result) result["cat_img"] += 1 assert_allclose(result["img1"], torch.tensor([[0, 1], [1, 2]], device=device)) assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3], [1, 2], [2, 3]], device=device)) @@ -42,8 +42,8 @@ def test_metatensor_values(self): "img2": MetaTensor([[0, 1], [1, 2]], device=device), } result = ConcatItemsd(keys=["img1", "img2"], name="cat_img")(input_data) - self.assertTrue("cat_img" in result) - self.assertTrue(isinstance(result["cat_img"], MetaTensor)) + self.assertIn("cat_img", result) + self.assertIsInstance(result["cat_img"], MetaTensor) self.assertEqual(result["img1"].meta, result["cat_img"].meta) result["cat_img"] += 1 assert_allclose(result["img1"], torch.tensor([[0, 1], [1, 2]], device=device)) @@ -52,7 +52,7 @@ def test_metatensor_values(self): def test_numpy_values(self): input_data = {"img1": np.array([[0, 1], [1, 2]]), "img2": np.array([[0, 1], [1, 2]])} result = ConcatItemsd(keys=["img1", "img2"], name="cat_img")(input_data) - self.assertTrue("cat_img" in result) + self.assertIn("cat_img", result) result["cat_img"] += 1 np.testing.assert_allclose(result["img1"], np.array([[0, 1], [1, 2]])) np.testing.assert_allclose(result["cat_img"], np.array([[1, 2], [2, 3], [1, 2], [2, 3]])) diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index cc890a0522..cf1edc8f08 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -185,7 +185,7 @@ def test_function(self, config): if id in ("compute", "cls_compute"): parser[f"{id}#_mode_"] = "callable" func = parser.get_parsed_content(id=id) - self.assertTrue(id in parser.ref_resolver.resolved_content) + self.assertIn(id, parser.ref_resolver.resolved_content) if id == "error_func": with self.assertRaises(TypeError): func(1, 2) diff --git a/tests/test_controlnet.py b/tests/test_controlnet.py index 07dfa2e49b..05ceb69fa3 100644 --- a/tests/test_controlnet.py +++ b/tests/test_controlnet.py @@ -12,13 +12,16 @@ from __future__ import annotations import unittest +from unittest import skipUnless import torch from parameterized import parameterized from monai.networks import eval_mode from monai.networks.nets.controlnet import ControlNet +from monai.utils import optional_import +_, has_einops = optional_import("einops") UNCOND_CASES_2D = [ [ { @@ -147,6 +150,7 @@ class TestControlNet(unittest.TestCase): @parameterized.expand(UNCOND_CASES_2D + UNCOND_CASES_3D) + @skipUnless(has_einops, "Requires einops") def test_shape_unconditioned_models(self, input_param, expected_output_shape): input_param["conditioning_embedding_in_channels"] = input_param["in_channels"] input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],) @@ -160,6 +164,7 @@ def test_shape_unconditioned_models(self, input_param, expected_output_shape): self.assertEqual(result[1].shape, expected_output_shape) @parameterized.expand(COND_CASES_2D) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models(self, input_param, expected_output_shape): input_param["conditioning_embedding_in_channels"] = input_param["in_channels"] input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],) diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py index 1f675537dc..96e707acb5 100644 --- a/tests/test_controlnet_inferers.py +++ b/tests/test_controlnet_inferers.py @@ -12,6 +12,7 @@ from __future__ import annotations import unittest +from unittest import skipUnless import torch from parameterized import parameterized @@ -29,6 +30,8 @@ from monai.utils import optional_import _, has_scipy = optional_import("scipy") +_, has_einops = optional_import("einops") + CNDM_TEST_CASES = [ [ @@ -443,6 +446,7 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_call(self, model_params, controlnet_params, input_shape): model = DiffusionModelUNet(**model_params) controlnet = ControlNet(**controlnet_params) @@ -464,6 +468,7 @@ def test_call(self, model_params, controlnet_params, input_shape): self.assertEqual(sample.shape, input_shape) @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sample_intermediates(self, model_params, controlnet_params, input_shape): model = DiffusionModelUNet(**model_params) controlnet = ControlNet(**controlnet_params) @@ -489,6 +494,7 @@ def test_sample_intermediates(self, model_params, controlnet_params, input_shape self.assertEqual(len(intermediates), 10) @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_ddpm_sampler(self, model_params, controlnet_params, input_shape): model = DiffusionModelUNet(**model_params) controlnet = ControlNet(**controlnet_params) @@ -514,6 +520,7 @@ def test_ddpm_sampler(self, model_params, controlnet_params, input_shape): self.assertEqual(len(intermediates), 10) @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_ddim_sampler(self, model_params, controlnet_params, input_shape): model = DiffusionModelUNet(**model_params) controlnet = ControlNet(**controlnet_params) @@ -539,6 +546,7 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape): self.assertEqual(len(intermediates), 10) @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sampler_conditioned(self, model_params, controlnet_params, input_shape): model_params["with_conditioning"] = True model_params["cross_attention_dim"] = 3 @@ -568,6 +576,7 @@ def test_sampler_conditioned(self, model_params, controlnet_params, input_shape) self.assertEqual(len(intermediates), 10) @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_get_likelihood(self, model_params, controlnet_params, input_shape): model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -604,6 +613,7 @@ def test_normal_cdf(self): torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape): # copy the model_params dict to prevent from modifying test cases model_params = model_params.copy() @@ -642,6 +652,7 @@ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_prediction_shape( self, ae_model_type, @@ -708,6 +719,7 @@ def test_prediction_shape( self.assertEqual(prediction.shape, latent_shape) @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sample_shape( self, ae_model_type, @@ -770,6 +782,7 @@ def test_sample_shape( self.assertEqual(sample.shape, input_shape) @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sample_intermediates( self, ae_model_type, @@ -837,6 +850,7 @@ def test_sample_intermediates( self.assertEqual(intermediates[0].shape, input_shape) @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_get_likelihoods( self, ae_model_type, @@ -904,6 +918,7 @@ def test_get_likelihoods( self.assertEqual(intermediates[0].shape, latent_shape) @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_resample_likelihoods( self, ae_model_type, @@ -973,6 +988,7 @@ def test_resample_likelihoods( self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_prediction_shape_conditioned_concat( self, ae_model_type, @@ -1053,6 +1069,7 @@ def test_prediction_shape_conditioned_concat( self.assertEqual(prediction.shape, latent_shape) @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sample_shape_conditioned_concat( self, ae_model_type, @@ -1128,6 +1145,7 @@ def test_sample_shape_conditioned_concat( self.assertEqual(sample.shape, input_shape) @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES) + @skipUnless(has_einops, "Requires einops") def test_sample_shape_different_latents( self, ae_model_type, @@ -1203,6 +1221,7 @@ def test_sample_shape_different_latents( ) self.assertEqual(prediction.shape, latent_shape) + @skipUnless(has_einops, "Requires einops") def test_incompatible_spade_setup(self): stage_1 = SPADEAutoencoderKL( spatial_dims=2, diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py new file mode 100644 index 0000000000..4ab0ab1823 --- /dev/null +++ b/tests/test_crossattention.py @@ -0,0 +1,131 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks.crossattention import CrossAttentionBlock +from monai.networks.layers.factories import RelPosEmbedding +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + +TEST_CASE_CABLOCK = [] +for dropout_rate in np.linspace(0, 1, 4): + for hidden_size in [360, 480, 600, 768]: + for num_heads in [4, 6, 8, 12]: + for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: + for input_size in [(16, 32), (8, 8, 8)]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding, + "input_size": input_size, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_CABLOCK.append(test_case) + + +class TestResBlock(unittest.TestCase): + + @parameterized.expand(TEST_CASE_CABLOCK) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + net = CrossAttentionBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param["hidden_size"])) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + CrossAttentionBlock(hidden_size=128, num_heads=12, dropout_rate=6.0) + + with self.assertRaises(ValueError): + CrossAttentionBlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + + @skipUnless(has_einops, "Requires einops") + def test_attention_dim_not_multiple_of_heads(self): + with self.assertRaises(ValueError): + CrossAttentionBlock(hidden_size=128, num_heads=3, dropout_rate=0.1) + + @skipUnless(has_einops, "Requires einops") + def test_inner_dim_different(self): + CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30) + + def test_causal_no_sequence_length(self): + with self.assertRaises(ValueError): + CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) + + @skipUnless(has_einops, "Requires einops") + def test_causal(self): + block = CrossAttentionBlock( + hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True + ) + input_shape = (1, 16, 128) + block(torch.randn(input_shape)) + # check upper triangular part of the attention matrix is zero + assert torch.triu(block.att_mat, diagonal=1).sum() == 0 + + @skipUnless(has_einops, "Requires einops") + def test_context_input(self): + block = CrossAttentionBlock( + hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12 + ) + input_shape = (1, 16, 128) + block(torch.randn(input_shape), context=torch.randn(1, 3, 12)) + + @skipUnless(has_einops, "Requires einops") + def test_context_wrong_input_size(self): + block = CrossAttentionBlock( + hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12 + ) + input_shape = (1, 16, 128) + with self.assertRaises(RuntimeError): + block(torch.randn(input_shape), context=torch.randn(1, 3, 24)) + + @skipUnless(has_einops, "Requires einops") + def test_access_attn_matrix(self): + # input format + hidden_size = 128 + num_heads = 2 + dropout_rate = 0 + input_shape = (2, 256, hidden_size) + + # be not able to access the matrix + no_matrix_acess_blk = CrossAttentionBlock( + hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate + ) + no_matrix_acess_blk(torch.randn(input_shape)) + assert isinstance(no_matrix_acess_blk.att_mat, torch.Tensor) + # no of elements is zero + assert no_matrix_acess_blk.att_mat.nelement() == 0 + + # be able to acess the attention matrix + matrix_acess_blk = CrossAttentionBlock( + hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True + ) + matrix_acess_blk(torch.randn(input_shape)) + assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cucim_dict_transform.py b/tests/test_cucim_dict_transform.py index d2dcc6aa5f..3c5703a34c 100644 --- a/tests/test_cucim_dict_transform.py +++ b/tests/test_cucim_dict_transform.py @@ -80,8 +80,8 @@ class TestCuCIMDict(unittest.TestCase): def test_tramsforms_numpy_single(self, params, input, expected): input = {"image": input} output = CuCIMd(keys="image", **params)(input)["image"] - self.assertTrue(output.dtype == expected.dtype) - self.assertTrue(isinstance(output, np.ndarray)) + self.assertEqual(output.dtype, expected.dtype) + self.assertIsInstance(output, np.ndarray) cp.testing.assert_allclose(output, expected) @parameterized.expand( @@ -98,8 +98,8 @@ def test_tramsforms_numpy_batch(self, params, input, expected): input = {"image": input[cp.newaxis, ...]} expected = expected[cp.newaxis, ...] output = CuCIMd(keys="image", **params)(input)["image"] - self.assertTrue(output.dtype == expected.dtype) - self.assertTrue(isinstance(output, np.ndarray)) + self.assertEqual(output.dtype, expected.dtype) + self.assertIsInstance(output, np.ndarray) cp.testing.assert_allclose(output, expected) @parameterized.expand( @@ -116,8 +116,8 @@ def test_tramsforms_cupy_single(self, params, input, expected): input = {"image": cp.asarray(input)} expected = cp.asarray(expected) output = CuCIMd(keys="image", **params)(input)["image"] - self.assertTrue(output.dtype == expected.dtype) - self.assertTrue(isinstance(output, cp.ndarray)) + self.assertEqual(output.dtype, expected.dtype) + self.assertIsInstance(output, cp.ndarray) cp.testing.assert_allclose(output, expected) @parameterized.expand( @@ -134,8 +134,8 @@ def test_tramsforms_cupy_batch(self, params, input, expected): input = {"image": cp.asarray(input)[cp.newaxis, ...]} expected = cp.asarray(expected)[cp.newaxis, ...] output = CuCIMd(keys="image", **params)(input)["image"] - self.assertTrue(output.dtype == expected.dtype) - self.assertTrue(isinstance(output, cp.ndarray)) + self.assertEqual(output.dtype, expected.dtype) + self.assertIsInstance(output, cp.ndarray) cp.testing.assert_allclose(output, expected) diff --git a/tests/test_cucim_transform.py b/tests/test_cucim_transform.py index 5f16c11589..162e16b52a 100644 --- a/tests/test_cucim_transform.py +++ b/tests/test_cucim_transform.py @@ -79,8 +79,8 @@ class TestCuCIM(unittest.TestCase): ) def test_tramsforms_numpy_single(self, params, input, expected): output = CuCIM(**params)(input) - self.assertTrue(output.dtype == expected.dtype) - self.assertTrue(isinstance(output, np.ndarray)) + self.assertEqual(output.dtype, expected.dtype) + self.assertIsInstance(output, np.ndarray) cp.testing.assert_allclose(output, expected) @parameterized.expand( @@ -97,8 +97,8 @@ def test_tramsforms_numpy_batch(self, params, input, expected): input = input[cp.newaxis, ...] expected = expected[cp.newaxis, ...] output = CuCIM(**params)(input) - self.assertTrue(output.dtype == expected.dtype) - self.assertTrue(isinstance(output, np.ndarray)) + self.assertEqual(output.dtype, expected.dtype) + self.assertIsInstance(output, np.ndarray) cp.testing.assert_allclose(output, expected) @parameterized.expand( @@ -115,8 +115,8 @@ def test_tramsforms_cupy_single(self, params, input, expected): input = cp.asarray(input) expected = cp.asarray(expected) output = CuCIM(**params)(input) - self.assertTrue(output.dtype == expected.dtype) - self.assertTrue(isinstance(output, cp.ndarray)) + self.assertEqual(output.dtype, expected.dtype) + self.assertIsInstance(output, cp.ndarray) cp.testing.assert_allclose(output, expected) @parameterized.expand( @@ -133,8 +133,8 @@ def test_tramsforms_cupy_batch(self, params, input, expected): input = cp.asarray(input)[cp.newaxis, ...] expected = cp.asarray(expected)[cp.newaxis, ...] output = CuCIM(**params)(input) - self.assertTrue(output.dtype == expected.dtype) - self.assertTrue(isinstance(output, cp.ndarray)) + self.assertEqual(output.dtype, expected.dtype) + self.assertIsInstance(output, cp.ndarray) cp.testing.assert_allclose(output, expected) diff --git a/tests/test_detect_envelope.py b/tests/test_detect_envelope.py index e2efefeb77..f9c2b5ac53 100644 --- a/tests/test_detect_envelope.py +++ b/tests/test_detect_envelope.py @@ -147,7 +147,7 @@ def test_value_error(self, arguments, image, method): elif method == "__call__": self.assertRaises(ValueError, DetectEnvelope(**arguments), image) else: - raise ValueError("Expected raising method invalid. Should be __init__ or __call__.") + self.fail("Expected raising method invalid. Should be __init__ or __call__.") @SkipIfModule("torch.fft") diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py index ecd4855385..7f37025d3c 100644 --- a/tests/test_diffusion_inferer.py +++ b/tests/test_diffusion_inferer.py @@ -12,6 +12,7 @@ from __future__ import annotations import unittest +from unittest import skipUnless import torch from parameterized import parameterized @@ -22,6 +23,7 @@ from monai.utils import optional_import _, has_scipy = optional_import("scipy") +_, has_einops = optional_import("einops") TEST_CASES = [ [ @@ -55,6 +57,7 @@ class TestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_call(self, model_params, input_shape): model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -70,6 +73,7 @@ def test_call(self, model_params, input_shape): self.assertEqual(sample.shape, input_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sample_intermediates(self, model_params, input_shape): model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -85,6 +89,7 @@ def test_sample_intermediates(self, model_params, input_shape): self.assertEqual(len(intermediates), 10) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_ddpm_sampler(self, model_params, input_shape): model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -100,6 +105,7 @@ def test_ddpm_sampler(self, model_params, input_shape): self.assertEqual(len(intermediates), 10) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_ddim_sampler(self, model_params, input_shape): model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -115,6 +121,7 @@ def test_ddim_sampler(self, model_params, input_shape): self.assertEqual(len(intermediates), 10) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sampler_conditioned(self, model_params, input_shape): model_params["with_conditioning"] = True model_params["cross_attention_dim"] = 3 @@ -138,6 +145,7 @@ def test_sampler_conditioned(self, model_params, input_shape): self.assertEqual(len(intermediates), 10) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_get_likelihood(self, model_params, input_shape): model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -166,6 +174,7 @@ def test_normal_cdf(self): torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sampler_conditioned_concat(self, model_params, input_shape): # copy the model_params dict to prevent from modifying test cases model_params = model_params.copy() @@ -196,6 +205,7 @@ def test_sampler_conditioned_concat(self, model_params, input_shape): self.assertEqual(len(intermediates), 10) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_call_conditioned_concat(self, model_params, input_shape): # copy the model_params dict to prevent from modifying test cases model_params = model_params.copy() diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index d40a31a1da..7f764d85de 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -11,13 +11,21 @@ from __future__ import annotations +import os +import tempfile import unittest +from unittest import skipUnless import torch from parameterized import parameterized +from monai.apps import download_url from monai.networks import eval_mode from monai.networks.nets import DiffusionModelUNet +from monai.utils import optional_import +from tests.utils import skip_if_downloading_fails, testing_data_config + +_, has_einops = optional_import("einops") UNCOND_CASES_2D = [ [ @@ -286,12 +294,14 @@ class TestDiffusionModelUNet2D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_2D) + @skipUnless(has_einops, "Requires einops") def test_shape_unconditioned_models(self, input_param): net = DiffusionModelUNet(**input_param) with eval_mode(net): result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long()) self.assertEqual(result.shape, (1, 1, 16, 16)) + @skipUnless(has_einops, "Requires einops") def test_timestep_with_wrong_shape(self): net = DiffusionModelUNet( spatial_dims=2, @@ -306,6 +316,7 @@ def test_timestep_with_wrong_shape(self): with eval_mode(net): net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long()) + @skipUnless(has_einops, "Requires einops") def test_shape_with_different_in_channel_out_channel(self): in_channels = 6 out_channels = 3 @@ -359,6 +370,7 @@ def test_num_res_blocks_with_different_length_channels(self): norm_num_groups=8, ) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models(self): net = DiffusionModelUNet( spatial_dims=2, @@ -396,6 +408,7 @@ def test_with_conditioning_cross_attention_dim_none(self): norm_num_groups=8, ) + @skipUnless(has_einops, "Requires einops") def test_context_with_conditioning_none(self): net = DiffusionModelUNet( spatial_dims=2, @@ -417,6 +430,7 @@ def test_context_with_conditioning_none(self): context=torch.rand((1, 1, 3)), ) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models_class_conditioning(self): net = DiffusionModelUNet( spatial_dims=2, @@ -437,6 +451,7 @@ def test_shape_conditioned_models_class_conditioning(self): ) self.assertEqual(result.shape, (1, 1, 16, 32)) + @skipUnless(has_einops, "Requires einops") def test_conditioned_models_no_class_labels(self): net = DiffusionModelUNet( spatial_dims=2, @@ -453,6 +468,7 @@ def test_conditioned_models_no_class_labels(self): with self.assertRaises(ValueError): net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long()) + @skipUnless(has_einops, "Requires einops") def test_model_channels_not_same_size_of_attention_levels(self): with self.assertRaises(ValueError): DiffusionModelUNet( @@ -468,6 +484,7 @@ def test_model_channels_not_same_size_of_attention_levels(self): ) @parameterized.expand(COND_CASES_2D) + @skipUnless(has_einops, "Requires einops") def test_conditioned_2d_models_shape(self, input_param): net = DiffusionModelUNet(**input_param) with eval_mode(net): @@ -477,12 +494,14 @@ def test_conditioned_2d_models_shape(self, input_param): class TestDiffusionModelUNet3D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_3D) + @skipUnless(has_einops, "Requires einops") def test_shape_unconditioned_models(self, input_param): net = DiffusionModelUNet(**input_param) with eval_mode(net): result = net.forward(torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + @skipUnless(has_einops, "Requires einops") def test_shape_with_different_in_channel_out_channel(self): in_channels = 6 out_channels = 3 @@ -499,6 +518,7 @@ def test_shape_with_different_in_channel_out_channel(self): result = net.forward(torch.rand((1, in_channels, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models(self): net = DiffusionModelUNet( spatial_dims=3, @@ -527,9 +547,39 @@ def test_wrong_dropout(self, input_param): _ = DiffusionModelUNet(**input_param) @parameterized.expand(DROPOUT_OK) + @skipUnless(has_einops, "Requires einops") def test_right_dropout(self, input_param): _ = DiffusionModelUNet(**input_param) + @skipUnless(has_einops, "Requires einops") + def test_compatibility_with_monai_generative(self): + # test loading weights from a model saved in MONAI Generative, version 0.2.3 + with skip_if_downloading_fails(): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + cross_attention_dim=3, + transformer_num_layers=1, + norm_num_groups=8, + ) + + tmpdir = tempfile.mkdtemp() + key = "diffusion_model_unet_monai_generative_weights" + url = testing_data_config("models", key, "url") + hash_type = testing_data_config("models", key, "hash_type") + hash_val = testing_data_config("models", key, "hash_val") + filename = "diffusion_model_unet_monai_generative_weights.pt" + + weight_path = os.path.join(tmpdir, filename) + download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) + + net.load_old_state_dict(torch.load(weight_path), verbose=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py index 09aa1f04b5..fe543347de 100644 --- a/tests/test_ensure_typed.py +++ b/tests/test_ensure_typed.py @@ -33,8 +33,8 @@ def test_array_input(self): keys="data", data_type=dtype, dtype=np.float32 if dtype == "NUMPY" else None, device="cpu" )({"data": test_data})["data"] if dtype == "NUMPY": - self.assertTrue(result.dtype == np.float32) - self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) + self.assertEqual(result.dtype, np.float32) + self.assertIsInstance(result, torch.Tensor if dtype == "tensor" else np.ndarray) assert_allclose(result, test_data, type_test=False) self.assertTupleEqual(result.shape, (2, 2)) @@ -45,7 +45,7 @@ def test_single_input(self): for test_data in test_datas: for dtype in ("tensor", "numpy"): result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"] - self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) + self.assertIsInstance(result, torch.Tensor if dtype == "tensor" else np.ndarray) if isinstance(test_data, bool): self.assertFalse(result) else: @@ -56,11 +56,11 @@ def test_string(self): for dtype in ("tensor", "numpy"): # string input result = EnsureTyped(keys="data", data_type=dtype)({"data": "test_string"})["data"] - self.assertTrue(isinstance(result, str)) + self.assertIsInstance(result, str) self.assertEqual(result, "test_string") # numpy array of string result = EnsureTyped(keys="data", data_type=dtype)({"data": np.array(["test_string"])})["data"] - self.assertTrue(isinstance(result, np.ndarray)) + self.assertIsInstance(result, np.ndarray) self.assertEqual(result[0], "test_string") def test_list_tuple(self): @@ -68,15 +68,15 @@ def test_list_tuple(self): result = EnsureTyped(keys="data", data_type=dtype, wrap_sequence=False, track_meta=True)( {"data": [[1, 2], [3, 4]]} )["data"] - self.assertTrue(isinstance(result, list)) - self.assertTrue(isinstance(result[0][1], MetaTensor if dtype == "tensor" else np.ndarray)) + self.assertIsInstance(result, list) + self.assertIsInstance(result[0][1], MetaTensor if dtype == "tensor" else np.ndarray) assert_allclose(result[1][0], torch.as_tensor(3), type_test=False) # tuple of numpy arrays result = EnsureTyped(keys="data", data_type=dtype, wrap_sequence=False)( {"data": (np.array([1, 2]), np.array([3, 4]))} )["data"] - self.assertTrue(isinstance(result, tuple)) - self.assertTrue(isinstance(result[0], torch.Tensor if dtype == "tensor" else np.ndarray)) + self.assertIsInstance(result, tuple) + self.assertIsInstance(result[0], torch.Tensor if dtype == "tensor" else np.ndarray) assert_allclose(result[1], torch.as_tensor([3, 4]), type_test=False) def test_dict(self): @@ -92,19 +92,19 @@ def test_dict(self): ) for key in ("data", "label"): result = trans[key] - self.assertTrue(isinstance(result, dict)) - self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray)) - self.assertTrue(isinstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray)) + self.assertIsInstance(result, dict) + self.assertIsInstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray) + self.assertIsInstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray) self.assertEqual(result["meta"]["path"], "temp/test") self.assertEqual(result["extra"], None) assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]), type_test=False) assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]), type_test=False) if dtype == "numpy": - self.assertTrue(trans["data"]["img"].dtype == np.float32) - self.assertTrue(trans["label"]["img"].dtype == np.int8) + self.assertEqual(trans["data"]["img"].dtype, np.float32) + self.assertEqual(trans["label"]["img"].dtype, np.int8) else: - self.assertTrue(trans["data"]["img"].dtype == torch.float32) - self.assertTrue(trans["label"]["img"].dtype == torch.int8) + self.assertEqual(trans["data"]["img"].dtype, torch.float32) + self.assertEqual(trans["label"]["img"].dtype, torch.int8) if __name__ == "__main__": diff --git a/tests/test_flipd.py b/tests/test_flipd.py index 277f387051..1df6d34056 100644 --- a/tests/test_flipd.py +++ b/tests/test_flipd.py @@ -78,7 +78,7 @@ def test_torch(self, spatial_axis, img: torch.Tensor, track_meta: bool, device): def test_meta_dict(self): xform = Flipd("image", [0, 1]) res = xform({"image": torch.zeros(1, 3, 4)}) - self.assertTrue(res["image"].applied_operations == res["image_transforms"]) + self.assertEqual(res["image"].applied_operations, res["image_transforms"]) if __name__ == "__main__": diff --git a/tests/test_freeze_layers.py b/tests/test_freeze_layers.py index 1bea4ed1b5..7be8e576bf 100644 --- a/tests/test_freeze_layers.py +++ b/tests/test_freeze_layers.py @@ -40,9 +40,9 @@ def test_freeze_vars(self, device): for name, param in model.named_parameters(): if "class_layer" in name: - self.assertEqual(param.requires_grad, False) + self.assertFalse(param.requires_grad) else: - self.assertEqual(param.requires_grad, True) + self.assertTrue(param.requires_grad) @parameterized.expand(TEST_CASES) def test_exclude_vars(self, device): @@ -53,9 +53,9 @@ def test_exclude_vars(self, device): for name, param in model.named_parameters(): if "class_layer" in name: - self.assertEqual(param.requires_grad, True) + self.assertTrue(param.requires_grad) else: - self.assertEqual(param.requires_grad, False) + self.assertFalse(param.requires_grad) if __name__ == "__main__": diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index 7499507129..5738f4a089 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -184,7 +184,7 @@ def test_differentiability(self): generalized_dice_loss = GeneralizedDiceLoss() loss = generalized_dice_loss(prediction, target) - self.assertNotEqual(loss.grad_fn, None) + self.assertIsNotNone(loss.grad_fn) def test_batch(self): prediction = torch.zeros(2, 3, 3, 3) @@ -194,7 +194,7 @@ def test_batch(self): generalized_dice_loss = GeneralizedDiceLoss(batch=True) loss = generalized_dice_loss(prediction, target) - self.assertNotEqual(loss.grad_fn, None) + self.assertIsNotNone(loss.grad_fn) def test_script(self): loss = GeneralizedDiceLoss() diff --git a/tests/test_get_package_version.py b/tests/test_get_package_version.py index ab9e69cd31..e9e1d8eca6 100644 --- a/tests/test_get_package_version.py +++ b/tests/test_get_package_version.py @@ -20,14 +20,14 @@ class TestGetVersion(unittest.TestCase): def test_default(self): output = get_package_version("42foobarnoexist") - self.assertTrue("UNKNOWN" in output) + self.assertIn("UNKNOWN", output) output = get_package_version("numpy") - self.assertFalse("UNKNOWN" in output) + self.assertNotIn("UNKNOWN", output) def test_msg(self): output = get_package_version("42foobarnoexist", "test") - self.assertTrue("test" in output) + self.assertIn("test", output) if __name__ == "__main__": diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py index 4b324eda1a..56af123548 100644 --- a/tests/test_grid_patch.py +++ b/tests/test_grid_patch.py @@ -124,11 +124,11 @@ def test_grid_patch_meta(self, input_parameters, image, expected, expected_meta) self.assertTrue(output.meta["path"] == expected_meta[0]["path"]) for output_patch, expected_patch, expected_patch_meta in zip(output, expected, expected_meta): assert_allclose(output_patch, expected_patch, type_test=False) - self.assertTrue(isinstance(output_patch, MetaTensor)) - self.assertTrue(output_patch.meta["location"] == expected_patch_meta["location"]) + self.assertIsInstance(output_patch, MetaTensor) + self.assertEqual(output_patch.meta["location"], expected_patch_meta["location"]) self.assertTrue(output_patch.meta["spatial_shape"], list(output_patch.shape[1:])) if "path" in expected_meta[0]: - self.assertTrue(output_patch.meta["path"] == expected_patch_meta["path"]) + self.assertEqual(output_patch.meta["path"], expected_patch_meta["path"]) if __name__ == "__main__": diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index f876cff2a3..52da5c179b 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -76,9 +76,9 @@ def _update_metric(engine): if has_key_word.match(line): content_count += 1 if epoch_log is True: - self.assertTrue(content_count == max_epochs) + self.assertEqual(content_count, max_epochs) else: - self.assertTrue(content_count == 2) # 2 = len([1, 2]) from event_filter + self.assertEqual(content_count, 2) # 2 = len([1, 2]) from event_filter @parameterized.expand([[True], [get_event_filter([1, 3])]]) def test_loss_print(self, iteration_log): @@ -116,9 +116,9 @@ def _train_func(engine, batch): if has_key_word.match(line): content_count += 1 if iteration_log is True: - self.assertTrue(content_count == num_iters * max_epochs) + self.assertEqual(content_count, num_iters * max_epochs) else: - self.assertTrue(content_count == 2) # 2 = len([1, 3]) from event_filter + self.assertEqual(content_count, 2) # 2 = len([1, 3]) from event_filter def test_loss_dict(self): log_stream = StringIO() @@ -150,7 +150,7 @@ def _train_func(engine, batch): for line in output_str.split("\n"): if has_key_word.match(line): content_count += 1 - self.assertTrue(content_count > 0) + self.assertGreater(content_count, 0) def test_loss_file(self): key_to_handler = "test_logging" @@ -184,7 +184,7 @@ def _train_func(engine, batch): for line in output_str.split("\n"): if has_key_word.match(line): content_count += 1 - self.assertTrue(content_count > 0) + self.assertGreater(content_count, 0) def test_exception(self): # set up engine @@ -239,7 +239,7 @@ def _update_metric(engine): for line in output_str.split("\n"): if has_key_word.match(line): content_count += 1 - self.assertTrue(content_count > 0) + self.assertGreater(content_count, 0) def test_default_logger(self): log_stream = StringIO() @@ -274,7 +274,7 @@ def _train_func(engine, batch): for line in output_str.split("\n"): if has_key_word.match(line): content_count += 1 - self.assertTrue(content_count > 0) + self.assertGreater(content_count, 0) if __name__ == "__main__": diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py index c2e0fb55b7..60aaef05bf 100644 --- a/tests/test_integration_bundle_run.py +++ b/tests/test_integration_bundle_run.py @@ -135,9 +135,8 @@ def test_scripts_fold(self): command_run = cmd + ["run", "training", "--config_file", config_file, "--meta_file", meta_file] completed_process = subprocess.run(command_run, check=True, capture_output=True, text=True) output = repr(completed_process.stdout).replace("\\n", "\n").replace("\\t", "\t") # Get the captured output - print(output) - self.assertTrue(expected_condition in output) + self.assertIn(expected_condition, output) command_run_workflow = cmd + [ "run_workflow", "--run_id", @@ -149,8 +148,7 @@ def test_scripts_fold(self): ] completed_process = subprocess.run(command_run_workflow, check=True, capture_output=True, text=True) output = repr(completed_process.stdout).replace("\\n", "\n").replace("\\t", "\t") # Get the captured output - print(output) - self.assertTrue(expected_condition in output) + self.assertIn(expected_condition, output) # test missing meta file self.assertIn("ERROR", command_line_tests(cmd + ["run", "training", "--config_file", config_file])) diff --git a/tests/test_inverse_collation.py b/tests/test_inverse_collation.py index f33b5c67eb..bf3972e6bd 100644 --- a/tests/test_inverse_collation.py +++ b/tests/test_inverse_collation.py @@ -133,7 +133,7 @@ def test_collation(self, _, transform, collate_fn, ndim): d = decollate_batch(item) self.assertTrue(len(d) <= self.batch_size) for b in d: - self.assertTrue(isinstance(b["image"], MetaTensor)) + self.assertIsInstance(b["image"], MetaTensor) np.testing.assert_array_equal( b["image"].applied_operations[-1]["orig_size"], b["label"].applied_operations[-1]["orig_size"] ) diff --git a/tests/test_invertd.py b/tests/test_invertd.py index c32a3af643..f6e8fc40e7 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -134,7 +134,7 @@ def test_invert(self): # 25300: 2 workers (cpu, non-macos) # 1812: 0 workers (gpu or macos) # 1821: windows torch 1.10.0 - self.assertTrue((reverted.size - n_good) < 40000, f"diff. {reverted.size - n_good}") + self.assertLess((reverted.size - n_good), 40000, f"diff. {reverted.size - n_good}") set_determinism(seed=None) diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index 4ab803bb6f..065ebafd95 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -12,6 +12,7 @@ from __future__ import annotations import unittest +from unittest import skipUnless import torch from parameterized import parameterized @@ -19,7 +20,9 @@ from monai.inferers import LatentDiffusionInferer from monai.networks.nets import VQVAE, AutoencoderKL, DiffusionModelUNet, SPADEAutoencoderKL, SPADEDiffusionModelUNet from monai.networks.schedulers import DDPMScheduler +from monai.utils import optional_import +_, has_einops = optional_import("einops") TEST_CASES = [ [ "AutoencoderKL", @@ -313,6 +316,7 @@ class TestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_prediction_shape( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): @@ -360,6 +364,7 @@ def test_prediction_shape( self.assertEqual(prediction.shape, latent_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sample_shape( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): @@ -404,6 +409,7 @@ def test_sample_shape( self.assertEqual(sample.shape, input_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sample_intermediates( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): @@ -458,6 +464,7 @@ def test_sample_intermediates( self.assertEqual(intermediates[0].shape, input_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_get_likelihoods( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): @@ -510,6 +517,7 @@ def test_get_likelihoods( self.assertEqual(intermediates[0].shape, latent_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_resample_likelihoods( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): @@ -564,6 +572,7 @@ def test_resample_likelihoods( self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_prediction_shape_conditioned_concat( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): @@ -629,6 +638,7 @@ def test_prediction_shape_conditioned_concat( self.assertEqual(prediction.shape, latent_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sample_shape_conditioned_concat( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): @@ -689,6 +699,7 @@ def test_sample_shape_conditioned_concat( self.assertEqual(sample.shape, input_shape) @parameterized.expand(TEST_CASES_DIFF_SHAPES) + @skipUnless(has_einops, "Requires einops") def test_sample_shape_different_latents( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): @@ -745,6 +756,7 @@ def test_sample_shape_different_latents( ) self.assertEqual(prediction.shape, latent_shape) + @skipUnless(has_einops, "Requires einops") def test_incompatible_spade_setup(self): stage_1 = SPADEAutoencoderKL( spatial_dims=2, diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py index 699ed70059..914240c705 100644 --- a/tests/test_load_imaged.py +++ b/tests/test_load_imaged.py @@ -190,7 +190,7 @@ def test_correct(self, input_p, expected_shape, track_meta): self.assertTrue(hasattr(r, "affine")) self.assertIsInstance(r.affine, torch.Tensor) self.assertEqual(r.meta["space"], "RAS") - self.assertTrue("qform_code" not in r.meta) + self.assertNotIn("qform_code", r.meta) else: self.assertIsInstance(r, torch.Tensor) self.assertNotIsInstance(r, MetaTensor) diff --git a/tests/test_load_spacing_orientation.py b/tests/test_load_spacing_orientation.py index 63422761ca..cbc730e1bb 100644 --- a/tests/test_load_spacing_orientation.py +++ b/tests/test_load_spacing_orientation.py @@ -48,7 +48,7 @@ def test_load_spacingd(self, filename): ref = resample_to_output(anat, (1, 0.2, 1), order=1) t2 = time.time() print(f"time scipy: {t2 - t1}") - self.assertTrue(t2 >= t1) + self.assertGreaterEqual(t2, t1) np.testing.assert_allclose(res_dict["image"].affine, ref.affine) np.testing.assert_allclose(res_dict["image"].shape[1:], ref.shape) np.testing.assert_allclose(ref.get_fdata(), res_dict["image"][0], atol=0.05) @@ -68,7 +68,7 @@ def test_load_spacingd_rotate(self, filename): ref = resample_to_output(anat, (1, 2, 3), order=1) t2 = time.time() print(f"time scipy: {t2 - t1}") - self.assertTrue(t2 >= t1) + self.assertGreaterEqual(t2, t1) np.testing.assert_allclose(res_dict["image"].affine, ref.affine) if "anatomical" not in filename: np.testing.assert_allclose(res_dict["image"].shape[1:], ref.shape) diff --git a/tests/test_look_up_option.py b/tests/test_look_up_option.py index d40b7eaa8c..75560b4ac4 100644 --- a/tests/test_look_up_option.py +++ b/tests/test_look_up_option.py @@ -56,7 +56,7 @@ def test_default(self): def test_str_enum(self): output = look_up_option("C", {"A", "B"}, default=None) - self.assertEqual(output, None) + self.assertIsNone(output) self.assertEqual(list(_CaseStrEnum), ["A", "B"]) self.assertEqual(_CaseStrEnum.MODE_A, "A") self.assertEqual(str(_CaseStrEnum.MODE_A), "A") diff --git a/tests/test_matshow3d.py b/tests/test_matshow3d.py index e513025e69..e54bb523e4 100644 --- a/tests/test_matshow3d.py +++ b/tests/test_matshow3d.py @@ -78,7 +78,7 @@ def test_samples(self): fig, mat = matshow3d( [im[keys] for im in ims], title=f"testing {keys}", figsize=(2, 2), frames_per_row=5, every_n=2, show=False ) - self.assertTrue(mat.dtype == np.float32) + self.assertEqual(mat.dtype, np.float32) with tempfile.TemporaryDirectory() as tempdir: tempimg = f"{tempdir}/matshow3d_patch_test.png" diff --git a/tests/test_median_filter.py b/tests/test_median_filter.py index 516388afce..bdfdf24f9f 100644 --- a/tests/test_median_filter.py +++ b/tests/test_median_filter.py @@ -21,13 +21,13 @@ class MedianFilterTestCase(unittest.TestCase): + @parameterized.expand([(torch.ones(1, 1, 2, 3, 5), [1, 2, 4]), (torch.ones(1, 1, 4, 3, 4), 1)]) # 3d_big # 3d def test_3d(self, input_tensor, radius): filter = MedianFilter(radius).to(torch.device("cpu:0")) expected = input_tensor.numpy() output = filter(input_tensor).cpu().numpy() - np.testing.assert_allclose(output, expected, rtol=1e-5) def test_3d_radii(self): diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index 1db632c144..c1b21e9373 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -41,7 +41,7 @@ def _test_dataset(dataset): self.assertEqual(len(dataset), int(MEDNIST_FULL_DATASET_LENGTH * dataset.test_frac)) self.assertTrue("image" in dataset[0]) self.assertTrue("label" in dataset[0]) - self.assertTrue(isinstance(dataset[0]["image"], MetaTensor)) + self.assertIsInstance(dataset[0]["image"], MetaTensor) self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64)) with skip_if_downloading_fails(): diff --git a/tests/test_meta_affine.py b/tests/test_meta_affine.py index 95764a0c89..890734391f 100644 --- a/tests/test_meta_affine.py +++ b/tests/test_meta_affine.py @@ -160,7 +160,7 @@ def test_linear_consistent(self, xform_cls, input_dict, atol): diff = np.abs(itk.GetArrayFromImage(ref_2) - itk.GetArrayFromImage(expected)) avg_diff = np.mean(diff) - self.assertTrue(avg_diff < atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}") + self.assertLess(avg_diff, atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}") @parameterized.expand(TEST_CASES_DICT) def test_linear_consistent_dict(self, xform_cls, input_dict, atol): @@ -175,7 +175,7 @@ def test_linear_consistent_dict(self, xform_cls, input_dict, atol): diff = {k: np.abs(itk.GetArrayFromImage(ref_2[k]) - itk.GetArrayFromImage(expected[k])) for k in keys} avg_diff = {k: np.mean(diff[k]) for k in keys} for k in keys: - self.assertTrue(avg_diff[k] < atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}") + self.assertLess(avg_diff[k], atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}") if __name__ == "__main__": diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 1e0f188b63..f31a07eba4 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -222,9 +222,9 @@ def test_stack(self, device, dtype): def test_get_set_meta_fns(self): set_track_meta(False) - self.assertEqual(get_track_meta(), False) + self.assertFalse(get_track_meta()) set_track_meta(True) - self.assertEqual(get_track_meta(), True) + self.assertTrue(get_track_meta()) @parameterized.expand(TEST_DEVICES) def test_torchscript(self, device): diff --git a/tests/test_mmar_download.py b/tests/test_mmar_download.py index 6af3d09fb2..2ac73a8149 100644 --- a/tests/test_mmar_download.py +++ b/tests/test_mmar_download.py @@ -142,7 +142,7 @@ def test_load_ckpt(self, input_args, expected_name, expected_val): def test_unique(self): # model ids are unique keys = sorted(m["id"] for m in MODEL_DESC) - self.assertTrue(keys == sorted(set(keys))) + self.assertEqual(keys, sorted(set(keys))) def test_search(self): self.assertEqual(_get_val({"a": 1, "b": 2}, key="b"), 2) diff --git a/tests/test_persistentdataset.py b/tests/test_persistentdataset.py index b7bf2fbb11..7c4969e283 100644 --- a/tests/test_persistentdataset.py +++ b/tests/test_persistentdataset.py @@ -165,7 +165,7 @@ def test_different_transforms(self): im1 = PersistentDataset([im], Identity(), cache_dir=path, hash_transform=json_hashing)[0] im2 = PersistentDataset([im], Flip(1), cache_dir=path, hash_transform=json_hashing)[0] l2 = ((im1 - im2) ** 2).sum() ** 0.5 - self.assertTrue(l2 > 1) + self.assertGreater(l2, 1) if __name__ == "__main__": diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py index 950058a9e9..eb8ebd06c5 100644 --- a/tests/test_rand_affined.py +++ b/tests/test_rand_affined.py @@ -240,7 +240,7 @@ def test_rand_affined(self, input_param, input_data, expected_val, track_meta): resampler.lazy = False if input_param.get("cache_grid", False): - self.assertTrue(g.rand_affine._cached_grid is not None) + self.assertIsNotNone(g.rand_affine._cached_grid) for key in res: if isinstance(key, str) and key.endswith("_transforms"): continue diff --git a/tests/test_rand_bias_field.py b/tests/test_rand_bias_field.py index 333a9ecba5..328f46b7ee 100644 --- a/tests/test_rand_bias_field.py +++ b/tests/test_rand_bias_field.py @@ -39,7 +39,7 @@ def test_output_shape(self, class_args, img_shape): img = p(np.random.rand(*img_shape)) output = bias_field(img) np.testing.assert_equal(output.shape, img_shape) - self.assertTrue(output.dtype in (np.float32, torch.float32)) + self.assertIn(output.dtype, (np.float32, torch.float32)) img_zero = np.zeros([*img_shape]) output_zero = bias_field(img_zero) diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py index 1524442f61..a1414df0ac 100644 --- a/tests/test_rand_weighted_cropd.py +++ b/tests/test_rand_weighted_cropd.py @@ -154,7 +154,7 @@ def test_rand_weighted_cropd(self, _, init_params, input_data, expected_shape, e crop = RandWeightedCropd(**init_params) crop.set_random_state(10) result = crop(input_data) - self.assertTrue(len(result) == init_params["num_samples"]) + self.assertEqual(len(result), init_params["num_samples"]) _len = len(tuple(input_data.keys())) self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys())) diff --git a/tests/test_recon_net_utils.py b/tests/test_recon_net_utils.py index 1815000777..48d3b59a17 100644 --- a/tests/test_recon_net_utils.py +++ b/tests/test_recon_net_utils.py @@ -64,7 +64,7 @@ def test_reshape_channel_complex(self, test_data): def test_complex_normalize(self, test_data): result, mean, std = complex_normalize(test_data) result = result * std + mean - self.assertTrue((((result - test_data) ** 2).mean() ** 0.5).item() < 1e-5) + self.assertLess((((result - test_data) ** 2).mean() ** 0.5).item(), 1e-5) @parameterized.expand(TEST_PAD) def test_pad(self, test_data): diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index e8f82eb0c2..1fb81689e6 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -99,7 +99,7 @@ def forward(self, x): # backward pass loss_val.backward() optimizer.step() - self.assertTrue(init_loss > loss_val, "loss did not decrease") + self.assertGreater(init_loss, loss_val, "loss did not decrease") if __name__ == "__main__": diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 449edba4bf..5d34a32d8d 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -198,6 +198,14 @@ [model, *TEST_CASE_1] for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200] ] +CASE_EXTRACT_FEATURES = [ + ( + {"model_name": "resnet10", "pretrained": True, "spatial_dims": 3, "in_channels": 1}, + [1, 1, 64, 64, 64], + ([1, 64, 32, 32, 32], [1, 64, 16, 16, 16], [1, 128, 8, 8, 8], [1, 256, 4, 4, 4], [1, 512, 2, 2, 2]), + ) +] + CASE_EXTRACT_FEATURES = [ ( @@ -228,7 +236,7 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape): if input_param.get("feed_forward", True): self.assertEqual(result.shape, expected_shape) else: - self.assertTrue(result.shape in expected_shape) + self.assertIn(result.shape, expected_shape) @parameterized.expand(PRETRAINED_TEST_CASES) @skip_if_quick diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index d52cc71e55..d069d6aa30 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -62,6 +62,27 @@ def test_ill_arg(self): with self.assertRaises(ValueError): SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + def test_attention_dim_not_multiple_of_heads(self): + with self.assertRaises(ValueError): + SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1) + + @skipUnless(has_einops, "Requires einops") + def test_inner_dim_different(self): + SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30) + + def test_causal_no_sequence_length(self): + with self.assertRaises(ValueError): + SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) + + @skipUnless(has_einops, "Requires einops") + def test_causal(self): + block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True) + input_shape = (1, 16, 128) + block(torch.randn(input_shape)) + # check upper triangular part of the attention matrix is zero + assert torch.triu(block.att_mat, diagonal=1).sum() == 0 + + @skipUnless(has_einops, "Requires einops") def test_access_attn_matrix(self): # input format hidden_size = 128 @@ -83,6 +104,40 @@ def test_access_attn_matrix(self): matrix_acess_blk(torch.randn(input_shape)) assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) + def test_number_of_parameters(self): + + def count_sablock_params(*args, **kwargs): + """Count the number of parameters in a SABlock.""" + sablock = SABlock(*args, **kwargs) + return sum([x.numel() for x in sablock.parameters() if x.requires_grad]) + + hidden_size = 128 + num_heads = 8 + default_dim_head = hidden_size // num_heads + + # Default dim_head is hidden_size // num_heads + nparams_default = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads) + nparams_like_default = count_sablock_params( + hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head + ) + self.assertEqual(nparams_default, nparams_like_default) + + # Increasing dim_head should increase the number of parameters + nparams_custom_large = count_sablock_params( + hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head * 2 + ) + self.assertGreater(nparams_custom_large, nparams_default) + + # Decreasing dim_head should decrease the number of parameters + nparams_custom_small = count_sablock_params( + hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head // 2 + ) + self.assertGreater(nparams_default, nparams_custom_small) + + # Increasing the number of heads with the default behaviour should not change the number of params. + nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2) + self.assertEqual(nparams_default, nparams_default_more_heads) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_sobel_gradient.py b/tests/test_sobel_gradient.py index 3d995a60c9..a0d7cf5a8b 100644 --- a/tests/test_sobel_gradient.py +++ b/tests/test_sobel_gradient.py @@ -164,8 +164,8 @@ def test_sobel_gradients(self, image, arguments, expected_grad): ) def test_sobel_kernels(self, arguments, expected_kernels): sobel = SobelGradients(**arguments) - self.assertTrue(sobel.kernel_diff.dtype == expected_kernels[0].dtype) - self.assertTrue(sobel.kernel_smooth.dtype == expected_kernels[0].dtype) + self.assertEqual(sobel.kernel_diff.dtype, expected_kernels[0].dtype) + self.assertEqual(sobel.kernel_smooth.dtype, expected_kernels[0].dtype) assert_allclose(sobel.kernel_diff, expected_kernels[0]) assert_allclose(sobel.kernel_smooth, expected_kernels[1]) diff --git a/tests/test_sobel_gradientd.py b/tests/test_sobel_gradientd.py index 7499a0410b..03524823a5 100644 --- a/tests/test_sobel_gradientd.py +++ b/tests/test_sobel_gradientd.py @@ -187,8 +187,8 @@ def test_sobel_gradients(self, image_dict, arguments, expected_grad): ) def test_sobel_kernels(self, arguments, expected_kernels): sobel = SobelGradientsd(**arguments) - self.assertTrue(sobel.kernel_diff.dtype == expected_kernels[0].dtype) - self.assertTrue(sobel.kernel_smooth.dtype == expected_kernels[0].dtype) + self.assertEqual(sobel.kernel_diff.dtype, expected_kernels[0].dtype) + self.assertEqual(sobel.kernel_smooth.dtype, expected_kernels[0].dtype) assert_allclose(sobel.kernel_diff, expected_kernels[0]) assert_allclose(sobel.kernel_smooth, expected_kernels[1]) diff --git a/tests/test_spade_diffusion_model_unet.py b/tests/test_spade_diffusion_model_unet.py index 113e58ed89..481705f56f 100644 --- a/tests/test_spade_diffusion_model_unet.py +++ b/tests/test_spade_diffusion_model_unet.py @@ -12,13 +12,16 @@ from __future__ import annotations import unittest +from unittest import skipUnless import torch from parameterized import parameterized from monai.networks import eval_mode from monai.networks.nets import SPADEDiffusionModelUNet +from monai.utils import optional_import +einops, has_einops = optional_import("einops") UNCOND_CASES_2D = [ [ { @@ -262,6 +265,7 @@ class TestSPADEDiffusionModelUNet2D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_2D) + @skipUnless(has_einops, "Requires einops") def test_shape_unconditioned_models(self, input_param): net = SPADEDiffusionModelUNet(**input_param) with eval_mode(net): @@ -272,6 +276,7 @@ def test_shape_unconditioned_models(self, input_param): ) self.assertEqual(result.shape, (1, 1, 16, 16)) + @skipUnless(has_einops, "Requires einops") def test_timestep_with_wrong_shape(self): net = SPADEDiffusionModelUNet( spatial_dims=2, @@ -289,6 +294,7 @@ def test_timestep_with_wrong_shape(self): torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long(), torch.rand((1, 3, 16, 16)) ) + @skipUnless(has_einops, "Requires einops") def test_label_with_wrong_shape(self): net = SPADEDiffusionModelUNet( spatial_dims=2, @@ -304,6 +310,7 @@ def test_label_with_wrong_shape(self): with eval_mode(net): net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 6, 16, 16))) + @skipUnless(has_einops, "Requires einops") def test_shape_with_different_in_channel_out_channel(self): in_channels = 6 out_channels = 3 @@ -363,6 +370,7 @@ def test_num_res_blocks_with_different_length_channels(self): norm_num_groups=8, ) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models(self): net = SPADEDiffusionModelUNet( spatial_dims=2, @@ -387,6 +395,7 @@ def test_shape_conditioned_models(self): ) self.assertEqual(result.shape, (1, 1, 16, 32)) + @skipUnless(has_einops, "Requires einops") def test_with_conditioning_cross_attention_dim_none(self): with self.assertRaises(ValueError): SPADEDiffusionModelUNet( @@ -403,6 +412,7 @@ def test_with_conditioning_cross_attention_dim_none(self): norm_num_groups=8, ) + @skipUnless(has_einops, "Requires einops") def test_context_with_conditioning_none(self): net = SPADEDiffusionModelUNet( spatial_dims=2, @@ -426,6 +436,7 @@ def test_context_with_conditioning_none(self): context=torch.rand((1, 1, 3)), ) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models_class_conditioning(self): net = SPADEDiffusionModelUNet( spatial_dims=2, @@ -448,6 +459,7 @@ def test_shape_conditioned_models_class_conditioning(self): ) self.assertEqual(result.shape, (1, 1, 16, 32)) + @skipUnless(has_einops, "Requires einops") def test_conditioned_models_no_class_labels(self): net = SPADEDiffusionModelUNet( spatial_dims=2, @@ -485,6 +497,7 @@ def test_model_channels_not_same_size_of_attention_levels(self): ) @parameterized.expand(COND_CASES_2D) + @skipUnless(has_einops, "Requires einops") def test_conditioned_2d_models_shape(self, input_param): net = SPADEDiffusionModelUNet(**input_param) with eval_mode(net): @@ -499,6 +512,7 @@ def test_conditioned_2d_models_shape(self, input_param): class TestDiffusionModelUNet3D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_3D) + @skipUnless(has_einops, "Requires einops") def test_shape_unconditioned_models(self, input_param): net = SPADEDiffusionModelUNet(**input_param) with eval_mode(net): @@ -509,6 +523,7 @@ def test_shape_unconditioned_models(self, input_param): ) self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + @skipUnless(has_einops, "Requires einops") def test_shape_with_different_in_channel_out_channel(self): in_channels = 6 out_channels = 3 @@ -530,6 +545,7 @@ def test_shape_with_different_in_channel_out_channel(self): ) self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models(self): net = SPADEDiffusionModelUNet( spatial_dims=3, diff --git a/tests/test_spatialattention.py b/tests/test_spatialattention.py new file mode 100644 index 0000000000..70b78263c5 --- /dev/null +++ b/tests/test_spatialattention.py @@ -0,0 +1,55 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks.spatialattention import SpatialAttentionBlock +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + +TEST_CASES = [ + [ + {"spatial_dims": 2, "num_channels": 128, "num_head_channels": 32, "norm_num_groups": 32, "norm_eps": 1e-6}, + (1, 128, 32, 32), + (1, 128, 32, 32), + ], + [ + {"spatial_dims": 3, "num_channels": 16, "num_head_channels": 8, "norm_num_groups": 8, "norm_eps": 1e-6}, + (1, 16, 8, 8, 8), + (1, 16, 8, 8, 8), + ], +] + + +class TestBlock(unittest.TestCase): + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + net = SpatialAttentionBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_attention_dim_not_multiple_of_heads(self): + with self.assertRaises(ValueError): + SpatialAttentionBlock(spatial_dims=2, num_channels=128, num_head_channels=33) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py index 9551dec703..568461748b 100644 --- a/tests/test_threadcontainer.py +++ b/tests/test_threadcontainer.py @@ -62,7 +62,7 @@ def test_container(self): self.assertTrue(con.is_alive) self.assertIsNotNone(con.status()) - self.assertTrue(len(con.status_dict) > 0) + self.assertGreater(len(con.status_dict), 0) con.join() diff --git a/tests/test_to_cupy.py b/tests/test_to_cupy.py index 5a1754e7c5..38400f0d3f 100644 --- a/tests/test_to_cupy.py +++ b/tests/test_to_cupy.py @@ -62,8 +62,8 @@ def test_numpy_input_dtype(self): test_data = np.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) result = ToCupy(np.uint8)(test_data) - self.assertTrue(result.dtype == cp.uint8) - self.assertTrue(isinstance(result, cp.ndarray)) + self.assertEqual(result.dtype, cp.uint8) + self.assertIsInstance(result, cp.ndarray) self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) @@ -72,8 +72,8 @@ def test_tensor_input(self): test_data = test_data.rot90() self.assertFalse(test_data.is_contiguous()) result = ToCupy()(test_data) - self.assertTrue(result.dtype == cp.float32) - self.assertTrue(isinstance(result, cp.ndarray)) + self.assertEqual(result.dtype, cp.float32) + self.assertIsInstance(result, cp.ndarray) self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) @@ -83,8 +83,8 @@ def test_tensor_cuda_input(self): test_data = test_data.rot90() self.assertFalse(test_data.is_contiguous()) result = ToCupy()(test_data) - self.assertTrue(result.dtype == cp.float32) - self.assertTrue(isinstance(result, cp.ndarray)) + self.assertEqual(result.dtype, cp.float32) + self.assertIsInstance(result, cp.ndarray) self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) @@ -95,8 +95,8 @@ def test_tensor_cuda_input_dtype(self): self.assertFalse(test_data.is_contiguous()) result = ToCupy(dtype="float32")(test_data) - self.assertTrue(result.dtype == cp.float32) - self.assertTrue(isinstance(result, cp.ndarray)) + self.assertEqual(result.dtype, cp.float32) + self.assertIsInstance(result, cp.ndarray) self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py index f92b7c0075..f4e5f80a29 100644 --- a/tests/test_to_numpy.py +++ b/tests/test_to_numpy.py @@ -32,7 +32,7 @@ def test_cupy_input(self): test_data = cp.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) result = ToNumpy()(test_data) - self.assertTrue(isinstance(result, np.ndarray)) + self.assertIsInstance(result, np.ndarray) self.assertTrue(result.flags["C_CONTIGUOUS"]) assert_allclose(result, test_data.get(), type_test=False) @@ -41,8 +41,8 @@ def test_numpy_input(self): test_data = np.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) result = ToNumpy(dtype="float32")(test_data) - self.assertTrue(isinstance(result, np.ndarray)) - self.assertTrue(result.dtype == np.float32) + self.assertIsInstance(result, np.ndarray) + self.assertEqual(result.dtype, np.float32) self.assertTrue(result.flags["C_CONTIGUOUS"]) assert_allclose(result, test_data, type_test=False) @@ -51,7 +51,7 @@ def test_tensor_input(self): test_data = test_data.rot90() self.assertFalse(test_data.is_contiguous()) result = ToNumpy(dtype=torch.uint8)(test_data) - self.assertTrue(isinstance(result, np.ndarray)) + self.assertIsInstance(result, np.ndarray) self.assertTrue(result.flags["C_CONTIGUOUS"]) assert_allclose(result, test_data, type_test=False) @@ -61,7 +61,7 @@ def test_tensor_cuda_input(self): test_data = test_data.rot90() self.assertFalse(test_data.is_contiguous()) result = ToNumpy()(test_data) - self.assertTrue(isinstance(result, np.ndarray)) + self.assertIsInstance(result, np.ndarray) self.assertTrue(result.flags["C_CONTIGUOUS"]) assert_allclose(result, test_data, type_test=False) @@ -77,7 +77,7 @@ def test_list_tuple(self): def test_single_value(self): for test_data in [5, np.array(5), torch.tensor(5)]: result = ToNumpy(dtype=np.uint8)(test_data) - self.assertTrue(isinstance(result, np.ndarray)) + self.assertIsInstance(result, np.ndarray) assert_allclose(result, np.asarray(test_data), type_test=False) self.assertEqual(result.ndim, 0) diff --git a/tests/test_torchvision_fc_model.py b/tests/test_torchvision_fc_model.py index 322cce1161..9cc19db62c 100644 --- a/tests/test_torchvision_fc_model.py +++ b/tests/test_torchvision_fc_model.py @@ -195,8 +195,8 @@ def test_get_module(self): mod = look_up_named_module("model.1.submodule.1.submodule.1.submodule.0.conv", net) self.assertTrue(str(mod).startswith("Conv2d")) self.assertIsInstance(set_named_module(net, "model", torch.nn.Identity()).model, torch.nn.Identity) - self.assertEqual(look_up_named_module("model.1.submodule.1.submodule.1.submodule.conv", net), None) - self.assertEqual(look_up_named_module("test attribute", net), None) + self.assertIsNone(look_up_named_module("model.1.submodule.1.submodule.1.submodule.conv", net)) + self.assertIsNone(look_up_named_module("test attribute", net)) if __name__ == "__main__": diff --git a/tests/test_traceable_transform.py b/tests/test_traceable_transform.py index dd139053e3..6a499b2dd9 100644 --- a/tests/test_traceable_transform.py +++ b/tests/test_traceable_transform.py @@ -33,12 +33,12 @@ def test_default(self): expected_key = "_transforms" a = _TraceTest() for x in a.transform_info_keys(): - self.assertTrue(x in a.get_transform_info()) + self.assertIn(x, a.get_transform_info()) self.assertEqual(a.trace_key(), expected_key) data = {"image": "test"} data = a(data) # adds to the stack - self.assertTrue(isinstance(data[expected_key], list)) + self.assertIsInstance(data[expected_key], list) self.assertEqual(data[expected_key][0]["class"], "_TraceTest") data = a(data) # adds to the stack diff --git a/tests/test_transformer.py b/tests/test_transformer.py index ea6ebdf50f..b371809d47 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -11,15 +11,22 @@ from __future__ import annotations +import os +import tempfile import unittest +from unittest import skipUnless import numpy as np import torch from parameterized import parameterized +from monai.apps import download_url from monai.networks import eval_mode from monai.networks.nets import DecoderOnlyTransformer +from monai.utils import optional_import +from tests.utils import skip_if_downloading_fails, testing_data_config +_, has_einops = optional_import("einops") TEST_CASES = [] for dropout_rate in np.linspace(0, 1, 2): for attention_layer_dim in [360, 480, 600, 768]: @@ -40,12 +47,14 @@ class TestDecoderOnlyTransformer(unittest.TestCase): @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_unconditioned_models(self, input_param): net = DecoderOnlyTransformer(**input_param) with eval_mode(net): net.forward(torch.randint(0, 10, (1, 16))) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_conditioned_models(self, input_param): net = DecoderOnlyTransformer(**input_param, with_cross_attention=True) with eval_mode(net): @@ -57,7 +66,9 @@ def test_attention_dim_not_multiple_of_heads(self): num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=3 ) + @skipUnless(has_einops, "Requires einops") def test_dropout_rate_negative(self): + with self.assertRaises(ValueError): DecoderOnlyTransformer( num_tokens=10, @@ -68,6 +79,31 @@ def test_dropout_rate_negative(self): embedding_dropout_rate=-1, ) + @skipUnless(has_einops, "Requires einops") + def test_compatibility_with_monai_generative(self): + # test loading weights from a model saved in MONAI Generative, version 0.2.3 + with skip_if_downloading_fails(): + net = DecoderOnlyTransformer( + num_tokens=10, + max_seq_len=16, + attn_layers_dim=8, + attn_layers_depth=2, + attn_layers_heads=2, + with_cross_attention=True, + embedding_dropout_rate=0, + ) + + tmpdir = tempfile.mkdtemp() + key = "decoder_only_transformer_monai_generative_weights" + url = testing_data_config("models", key, "url") + hash_type = testing_data_config("models", key, "hash_type") + hash_val = testing_data_config("models", key, "hash_val") + filename = "decoder_only_transformer_monai_generative_weights.pt" + weight_path = os.path.join(tmpdir, filename) + download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) + + net.load_old_state_dict(torch.load(weight_path), verbose=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_transformerblock.py b/tests/test_transformerblock.py index 5a8dbba83c..a850cc6f74 100644 --- a/tests/test_transformerblock.py +++ b/tests/test_transformerblock.py @@ -12,6 +12,7 @@ from __future__ import annotations import unittest +from unittest import skipUnless import numpy as np import torch @@ -19,28 +20,33 @@ from monai.networks import eval_mode from monai.networks.blocks.transformerblock import TransformerBlock +from monai.utils import optional_import +einops, has_einops = optional_import("einops") TEST_CASE_TRANSFORMERBLOCK = [] for dropout_rate in np.linspace(0, 1, 4): for hidden_size in [360, 480, 600, 768]: for num_heads in [4, 8, 12]: for mlp_dim in [1024, 3072]: - test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "mlp_dim": mlp_dim, - "dropout_rate": dropout_rate, - }, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_TRANSFORMERBLOCK.append(test_case) + for cross_attention in [False, True]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "mlp_dim": mlp_dim, + "dropout_rate": dropout_rate, + "with_cross_attention": cross_attention, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_TRANSFORMERBLOCK.append(test_case) class TestTransformerBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_TRANSFORMERBLOCK) + @skipUnless(has_einops, "Requires einops") def test_shape(self, input_param, input_shape, expected_shape): net = TransformerBlock(**input_param) with eval_mode(net): @@ -54,6 +60,7 @@ def test_ill_arg(self): with self.assertRaises(ValueError): TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4) + @skipUnless(has_einops, "Requires einops") def test_access_attn_matrix(self): # input format hidden_size = 128 diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py index 1a511d287b..36b715f588 100644 --- a/tests/test_vqvaetransformer_inferer.py +++ b/tests/test_vqvaetransformer_inferer.py @@ -12,14 +12,17 @@ from __future__ import annotations import unittest +from unittest import skipUnless import torch from parameterized import parameterized from monai.inferers import VQVAETransformerInferer from monai.networks.nets import VQVAE, DecoderOnlyTransformer +from monai.utils import optional_import from monai.utils.ordering import Ordering, OrderingType +einops, has_einops = optional_import("einops") TEST_CASES = [ [ { @@ -78,6 +81,7 @@ class TestVQVAETransformerInferer(unittest.TestCase): @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_prediction_shape( self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape ): @@ -98,6 +102,7 @@ def test_prediction_shape( self.assertEqual(prediction.shape, logits_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_prediction_shape_shorter_sequence( self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape ): @@ -121,7 +126,9 @@ def test_prediction_shape_shorter_sequence( cropped_logits_shape = (logits_shape[0], max_seq_len, logits_shape[2]) self.assertEqual(prediction.shape, cropped_logits_shape) + @skipUnless(has_einops, "Requires einops") def test_sample(self): + stage_1 = VQVAE( spatial_dims=2, in_channels=1, @@ -163,6 +170,7 @@ def test_sample(self): ) self.assertEqual(sample.shape, (2, 1, 8, 8)) + @skipUnless(has_einops, "Requires einops") def test_sample_shorter_sequence(self): stage_1 = VQVAE( spatial_dims=2, @@ -206,6 +214,7 @@ def test_sample_shorter_sequence(self): self.assertEqual(sample.shape, (2, 1, 8, 8)) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_get_likelihood( self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape ): @@ -228,6 +237,7 @@ def test_get_likelihood( self.assertEqual(likelihood.shape, latent_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_get_likelihood_shorter_sequence( self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape ): @@ -253,6 +263,7 @@ def test_get_likelihood_shorter_sequence( self.assertEqual(likelihood.shape, latent_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_get_likelihood_resampling( self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape ): diff --git a/tests/test_warp.py b/tests/test_warp.py index bac595224f..55f40764c3 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -124,7 +124,7 @@ def test_itk_benchmark(self): relative_diff = np.mean( np.divide(monai_result - itk_result, itk_result, out=np.zeros_like(itk_result), where=(itk_result != 0)) ) - self.assertTrue(relative_diff < 0.01) + self.assertLess(relative_diff, 0.01) @parameterized.expand(TEST_CASES, skip_on_empty=True) def test_resample(self, input_param, input_data, expected_val): diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json index a570c787ba..318331e5f7 100644 --- a/tests/testing_data/data_config.json +++ b/tests/testing_data/data_config.json @@ -138,6 +138,21 @@ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth", "hash_type": "sha256", "hash_val": "c3564f40a6a051d3753a6d8fae5cc8eaf21ce8d82a9a3baf80748d15664055e8" + }, + "decoder_only_transformer_monai_generative_weights": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/decoder_only_transformer.pth", + "hash_type": "sha256", + "hash_val": "f93de37d64d77cf91f3bde95cdf93d161aee800074c89a92aff9d5699120ec0d" + }, + "diffusion_model_unet_monai_generative_weights": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/diffusion_model_unet.pth", + "hash_type": "sha256", + "hash_val": "0d2171b386902f5b4fd3e967b4024f63e353694ca45091b114970019d045beee" + }, + "autoencoderkl_monai_generative_weights": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/autoencoderkl.pth", + "hash_type": "sha256", + "hash_val": "6e02c9540c51b16b9ba98b5c0c75d6b84b430afe9a3237df1d67a520f8d34184" } }, "configs": { From c54bf3c18cd723e742b3809db26b5bb47d83ad57 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 13 May 2024 14:29:43 +0100 Subject: [PATCH 16/32] Tidy up init (#7755) Part of #7227 . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Mark Graham --- monai/networks/nets/patchgan_discriminator.py | 21 ++----------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/monai/networks/nets/patchgan_discriminator.py b/monai/networks/nets/patchgan_discriminator.py index 3b089616ce..74da917694 100644 --- a/monai/networks/nets/patchgan_discriminator.py +++ b/monai/networks/nets/patchgan_discriminator.py @@ -18,6 +18,7 @@ from monai.networks.blocks import Convolution from monai.networks.layers import Act +from monai.networks.utils import normal_init class MultiScalePatchDiscriminator(nn.Sequential): @@ -211,7 +212,7 @@ def __init__( ), ) - self.apply(self.initialise_weights) + self.apply(normal_init) def forward(self, x: torch.Tensor) -> list[torch.Tensor]: """ @@ -227,21 +228,3 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]: out.append(intermediate_output) return out[1:] - - def initialise_weights(self, m: nn.Module) -> None: - """ - Initialise weights of Convolution and BatchNorm layers. - - Args: - m: instance of torch.nn.module (or of class inheriting torch.nn.module) - """ - classname = m.__class__.__name__ - if classname.find("Conv2d") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("Conv3d") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("Conv1d") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("BatchNorm") != -1: - nn.init.normal_(m.weight.data, 1.0, 0.02) - nn.init.constant_(m.bias.data, 0) From a052c4435a088a1e87f2d41279822109edbd59fd Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 14 May 2024 10:35:45 +0100 Subject: [PATCH 17/32] Only have contigous calls after attention blocks (#7763) Towards #7227 . ### Description There were lots of contigous calls in the DiffusionModelUnet. It turns out these are necessary after attention blocks, as the einops operation sometimes leads to non-contigous tensors that can cause errors. I've tidied the code up so the .contiguous calls are only after attention calls. A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/networks/nets/diffusion_model_unet.py | 21 +++++++------------ .../nets/spade_diffusion_model_unet.py | 9 +++----- 2 files changed, 10 insertions(+), 20 deletions(-) diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 38d7f816a9..f995d20e54 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -115,10 +115,6 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch class SpatialTransformer(nn.Module): """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply standard transformer action. Finally, reshape to image. @@ -396,14 +392,11 @@ def __init__( ) def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: - h = x.contiguous() + h = x h = self.norm1(h) h = self.nonlinearity(h) if self.upsample is not None: - if h.shape[0] >= 64: - x = x.contiguous() - h = h.contiguous() x = self.upsample(x) h = self.upsample(h) elif self.downsample is not None: @@ -609,7 +602,7 @@ def forward( for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) + hidden_states = attn(hidden_states).contiguous() output_states.append(hidden_states) if self.downsampler is not None: @@ -726,7 +719,7 @@ def forward( for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=context) + hidden_states = attn(hidden_states, context=context).contiguous() output_states.append(hidden_states) if self.downsampler is not None: @@ -790,7 +783,7 @@ def forward( ) -> torch.Tensor: del context hidden_states = self.resnet_1(hidden_states, temb) - hidden_states = self.attention(hidden_states) + hidden_states = self.attention(hidden_states).contiguous() hidden_states = self.resnet_2(hidden_states, temb) return hidden_states @@ -1091,7 +1084,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) + hidden_states = attn(hidden_states).contiguous() if self.upsampler is not None: hidden_states = self.upsampler(hidden_states, temb) @@ -1669,7 +1662,7 @@ def forward( down_block_res_samples = new_down_block_res_samples # 5. mid - h = self.middle_block(hidden_states=h.contiguous(), temb=emb, context=context) + h = self.middle_block(hidden_states=h, temb=emb, context=context) # Additional residual conections for Controlnets if mid_block_additional_residual is not None: @@ -1682,7 +1675,7 @@ def forward( h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) # 7. output block - output: torch.Tensor = self.out(h.contiguous()) + output: torch.Tensor = self.out(h) return output diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py index e019d21c11..594b8068af 100644 --- a/monai/networks/nets/spade_diffusion_model_unet.py +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -170,9 +170,6 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor, seg: torch.Tensor) -> torc h = self.nonlinearity(h) if self.upsample is not None: - if h.shape[0] >= 64: - x = x.contiguous() - h = h.contiguous() x = self.upsample(x) h = self.upsample(h) elif self.downsample is not None: @@ -430,7 +427,7 @@ def forward( res_hidden_states_list = res_hidden_states_list[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb, seg) - hidden_states = attn(hidden_states) + hidden_states = attn(hidden_states).contiguous() if self.upsampler is not None: hidden_states = self.upsampler(hidden_states, temb) @@ -568,7 +565,7 @@ def forward( res_hidden_states_list = res_hidden_states_list[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb, seg) - hidden_states = attn(hidden_states, context=context) + hidden_states = attn(hidden_states, context=context).contiguous() if self.upsampler is not None: hidden_states = self.upsampler(hidden_states, temb) @@ -919,7 +916,7 @@ def forward( down_block_res_samples = new_down_block_res_samples # 5. mid - h = self.middle_block(hidden_states=h.contiguous(), temb=emb, context=context) + h = self.middle_block(hidden_states=h, temb=emb, context=context) # Additional residual conections for Controlnets if mid_block_additional_residual is not None: From a423bcd5cb9bc2537960189bb8be42ab29e210e7 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 22 May 2024 16:17:13 +0100 Subject: [PATCH 18/32] Neater use off nn.Sequential in controlnet (#7754) Part of #7227 . ### Description Tidies up some of controlnet A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/conf.py | 48 ++++++++++++++++++++- monai/networks/nets/controlnet.py | 65 ++++++++++++++++++++++++++--- monai/visualize/utils.py | 4 +- requirements-dev.txt | 2 +- setup.cfg | 4 +- tests/test_controlnet.py | 33 +++++++++++++++ tests/test_matshow3d.py | 1 + tests/testing_data/data_config.json | 5 +++ 8 files changed, 149 insertions(+), 13 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index fdb10fbe03..827626d12e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,6 +13,8 @@ import os import subprocess import sys +import importlib +import inspect sys.path.insert(0, os.path.abspath("..")) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) @@ -137,7 +139,7 @@ def generate_apidocs(*args): "github_user": "Project-MONAI", "github_repo": "MONAI", "github_version": "dev", - "doc_path": "docs/", + "doc_path": "docs/source", "conf_py_path": "/docs/", "VERSION": version, } @@ -162,3 +164,47 @@ def setup(app): # Hook to allow for automatic generation of API docs # before doc deployment begins. app.connect("builder-inited", generate_apidocs) + + +# -- Linkcode configuration -------------------------------------------------- +DEFAULT_REPOSITORY = "Project-MONAI/MONAI" +repository = os.environ.get("GITHUB_REPOSITORY", DEFAULT_REPOSITORY) + +base_code_url = f"https://github.com/{repository}/blob/{version}" +MODULE_ROOT_FOLDER = "monai" + + +# Adjusted from https://github.com/python-websockets/websockets/blob/main/docs/conf.py +def linkcode_resolve(domain, info): + if domain != "py": + raise ValueError( + f"expected domain to be 'py', got {domain}." + "Please adjust linkcode_resolve to either handle this domain or ignore it." + ) + + mod = importlib.import_module(info["module"]) + if "." in info["fullname"]: + objname, attrname = info["fullname"].split(".") + obj = getattr(mod, objname) + try: + # object is a method of a class + obj = getattr(obj, attrname) + except AttributeError: + # object is an attribute of a class + return None + else: + obj = getattr(mod, info["fullname"]) + + try: + file = inspect.getsourcefile(obj) + source, lineno = inspect.getsourcelines(obj) + except TypeError: + # e.g. object is a typing.Union + return None + file = os.path.relpath(file, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) + if not file.startswith(MODULE_ROOT_FOLDER): + # e.g. object is a typing.NewType + return None + start, end = lineno, lineno + len(source) - 1 + url = f"{base_code_url}/{file}#L{start}-L{end}" + return url diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py index 7450c87314..fe6746e017 100644 --- a/monai/networks/nets/controlnet.py +++ b/monai/networks/nets/controlnet.py @@ -34,7 +34,6 @@ from collections.abc import Sequence import torch -import torch.nn.functional as F from torch import nn from monai.networks.blocks import Convolution @@ -57,7 +56,8 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann strides=1, kernel_size=3, padding=1, - conv_only=True, + adn_ordering="A", + act="SWISH", ) self.blocks = nn.ModuleList([]) @@ -73,7 +73,8 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann strides=1, kernel_size=3, padding=1, - conv_only=True, + adn_ordering="A", + act="SWISH", ) ) @@ -85,7 +86,8 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann strides=2, kernel_size=3, padding=1, - conv_only=True, + adn_ordering="A", + act="SWISH", ) ) @@ -103,11 +105,9 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann def forward(self, conditioning): embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) for block in self.blocks: embedding = block(embedding) - embedding = F.silu(embedding) embedding = self.conv_out(embedding) @@ -410,3 +410,56 @@ def forward( mid_block_res_sample *= conditioning_scale return down_block_res_samples, mid_block_res_sample + + def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: + """ + Load a state dict from a ControlNet trained with + [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels). + + Args: + old_state_dict: state dict from the old ControlNet model. + """ + + new_state_dict = self.state_dict() + # if all keys match, just load the state dict + if all(k in new_state_dict for k in old_state_dict): + print("All keys match, loading state dict.") + self.load_state_dict(old_state_dict) + return + + if verbose: + # print all new_state_dict keys that are not in old_state_dict + for k in new_state_dict: + if k not in old_state_dict: + print(f"key {k} not found in old state dict") + # and vice versa + print("----------------------------------------------") + for k in old_state_dict: + if k not in new_state_dict: + print(f"key {k} not found in new state dict") + + # copy over all matching keys + for k in new_state_dict: + if k in old_state_dict: + new_state_dict[k] = old_state_dict[k] + + # fix the attention blocks + attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] + for block in attention_blocks: + new_state_dict[f"{block}.attn1.qkv.weight"] = torch.concat( + [ + old_state_dict[f"{block}.attn1.to_q.weight"], + old_state_dict[f"{block}.attn1.to_k.weight"], + old_state_dict[f"{block}.attn1.to_v.weight"], + ], + dim=0, + ) + + # projection + new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"] + new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"] + + new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"] + new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"] + + self.load_state_dict(new_state_dict) diff --git a/monai/visualize/utils.py b/monai/visualize/utils.py index f6718fe7a5..88c9a0d66a 100644 --- a/monai/visualize/utils.py +++ b/monai/visualize/utils.py @@ -24,11 +24,9 @@ from monai.utils.type_conversion import convert_data_type, convert_to_dst_type if TYPE_CHECKING: - from matplotlib import cm from matplotlib import pyplot as plt else: plt, _ = optional_import("matplotlib", name="pyplot") - cm, _ = optional_import("matplotlib", name="cm") __all__ = ["matshow3d", "blend_images"] @@ -210,7 +208,7 @@ def blend_images( image = repeat(image, 3, axis=0) def get_label_rgb(cmap: str, label: NdarrayOrTensor) -> NdarrayOrTensor: - _cmap = cm.get_cmap(cmap) + _cmap = plt.colormaps.get_cmap(cmap) label_np, *_ = convert_data_type(label, np.ndarray) label_rgb_np = _cmap(label_np[0]) label_rgb_np = np.moveaxis(label_rgb_np, -1, 0)[:3] diff --git a/requirements-dev.txt b/requirements-dev.txt index ce28d3ebe2..35ff3382be 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -36,7 +36,7 @@ einops transformers>=4.36.0 mlflow>=1.28.0, <=2.11.3 clearml>=1.10.0rc0 -matplotlib!=3.5.0 +matplotlib>=3.6.3 tensorboardX types-PyYAML pyyaml diff --git a/setup.cfg b/setup.cfg index c8ae1630f7..c90b043c1c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -68,7 +68,7 @@ all = transformers<4.22; python_version <= '3.10' mlflow>=1.28.0, <=2.11.3 clearml>=1.10.0rc0 - matplotlib + matplotlib>=3.6.3 tensorboardX pyyaml fire @@ -127,7 +127,7 @@ transformers = mlflow = mlflow>=1.28.0, <=2.11.3 matplotlib = - matplotlib + matplotlib>=3.6.3 clearml = clearml tensorboardX = diff --git a/tests/test_controlnet.py b/tests/test_controlnet.py index 05ceb69fa3..4746c7ce22 100644 --- a/tests/test_controlnet.py +++ b/tests/test_controlnet.py @@ -11,15 +11,19 @@ from __future__ import annotations +import os +import tempfile import unittest from unittest import skipUnless import torch from parameterized import parameterized +from monai.apps import download_url from monai.networks import eval_mode from monai.networks.nets.controlnet import ControlNet from monai.utils import optional_import +from tests.utils import skip_if_downloading_fails, testing_data_config _, has_einops = optional_import("einops") UNCOND_CASES_2D = [ @@ -177,6 +181,35 @@ def test_shape_conditioned_models(self, input_param, expected_output_shape): self.assertEqual(len(result[0]), 2 * len(input_param["channels"])) self.assertEqual(result[1].shape, expected_output_shape) + @skipUnless(has_einops, "Requires einops") + def test_compatibility_with_monai_generative(self): + # test loading weights from a model saved in MONAI Generative, version 0.2.3 + with skip_if_downloading_fails(): + net = ControlNet( + spatial_dims=2, + in_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + resblock_updown=True, + ) + + tmpdir = tempfile.mkdtemp() + key = "controlnet_monai_generative_weights" + url = testing_data_config("models", key, "url") + hash_type = testing_data_config("models", key, "hash_type") + hash_val = testing_data_config("models", key, "hash_val") + filename = "controlnet_monai_generative_weights.pt" + + weight_path = os.path.join(tmpdir, filename) + download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) + + net.load_old_state_dict(torch.load(weight_path), verbose=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_matshow3d.py b/tests/test_matshow3d.py index e54bb523e4..2eba310f4e 100644 --- a/tests/test_matshow3d.py +++ b/tests/test_matshow3d.py @@ -114,6 +114,7 @@ def test_3d_rgb(self): every_n=2, frame_dim=-1, channel_dim=0, + fill_value=0, show=False, ) diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json index 318331e5f7..8b1d2868b7 100644 --- a/tests/testing_data/data_config.json +++ b/tests/testing_data/data_config.json @@ -153,6 +153,11 @@ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/autoencoderkl.pth", "hash_type": "sha256", "hash_val": "6e02c9540c51b16b9ba98b5c0c75d6b84b430afe9a3237df1d67a520f8d34184" + }, + "controlnet_monai_generative_weights": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/controlnet.pth", + "hash_type": "sha256", + "hash_val": "cd100d0c69f47569ae5b4b7df653a1cb19f5e02eff1630db3210e2646fb1ab2e" } }, "configs": { From 36511cc93e941e621b8a582b28b20cc2ec81f033 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez <61539159+virginiafdez@users.noreply.github.com> Date: Mon, 3 Jun 2024 15:37:13 +0100 Subject: [PATCH 19/32] Addition of SPADE Network + tests and modification of SPADE normalisation (#7775) - spade_network, and SPADENet (VAE-GAN) - test_spade_vaegan (to test previously mentioned model) Modification of: - spade_diffusion_model_unet.py to change namings. - SPADE normalisation layer, to use get_norm_layer function instead of defining such layers directly. Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. --------- Signed-off-by: Mark Graham Signed-off-by: virginiafdez Co-authored-by: virginiafdez Co-authored-by: Mark Graham Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/spade_norm.py | 7 +- monai/networks/nets/__init__.py | 1 + .../nets/spade_diffusion_model_unet.py | 8 +- monai/networks/nets/spade_network.py | 435 ++++++++++++++++++ tests/test_spade_vaegan.py | 140 ++++++ 5 files changed, 583 insertions(+), 8 deletions(-) create mode 100644 monai/networks/nets/spade_network.py create mode 100644 tests/test_spade_vaegan.py diff --git a/monai/networks/blocks/spade_norm.py b/monai/networks/blocks/spade_norm.py index 8e082defe0..343dfa9ec0 100644 --- a/monai/networks/blocks/spade_norm.py +++ b/monai/networks/blocks/spade_norm.py @@ -15,7 +15,8 @@ import torch.nn as nn import torch.nn.functional as F -from monai.networks.blocks import ADN, Convolution +from monai.networks.blocks import Convolution +from monai.networks.layers.utils import get_norm_layer class SPADE(nn.Module): @@ -50,9 +51,7 @@ def __init__( norm_params = {} if len(norm_params) != 0: norm = (norm, norm_params) - self.param_free_norm = ADN( - act=None, dropout=0.0, norm=norm, norm_dim=spatial_dims, ordering="N", in_channels=norm_nc - ) + self.param_free_norm = get_norm_layer(norm, spatial_dims=spatial_dims, channels=norm_nc) self.mlp_shared = Convolution( spatial_dims=spatial_dims, in_channels=label_nc, diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 9101ab862e..c777fe6442 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -110,6 +110,7 @@ ) from .spade_autoencoderkl import SPADEAutoencoderKL from .spade_diffusion_model_unet import SPADEDiffusionModelUNet +from .spade_network import SPADENet from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR from .torchvision_fc import TorchVisionFCModel from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py index 594b8068af..75d1687df3 100644 --- a/monai/networks/nets/spade_diffusion_model_unet.py +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -53,7 +53,7 @@ __all__ = ["SPADEDiffusionModelUNet"] -class SPADEResnetBlock(nn.Module): +class SPADEDiffResBlock(nn.Module): """ Residual block with timestep conditioning and SPADE norm. Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) @@ -235,7 +235,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - SPADEResnetBlock( + SPADEDiffResBlock( spatial_dims=spatial_dims, in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, @@ -353,7 +353,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - SPADEResnetBlock( + SPADEDiffResBlock( spatial_dims=spatial_dims, in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, @@ -488,7 +488,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - SPADEResnetBlock( + SPADEDiffResBlock( spatial_dims=spatial_dims, in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, diff --git a/monai/networks/nets/spade_network.py b/monai/networks/nets/spade_network.py new file mode 100644 index 0000000000..9164541f27 --- /dev/null +++ b/monai/networks/nets/spade_network.py @@ -0,0 +1,435 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution +from monai.networks.blocks.spade_norm import SPADE +from monai.networks.layers import Act +from monai.networks.layers.utils import get_act_layer +from monai.utils.enums import StrEnum + +__all__ = ["SPADENet"] + + +class UpsamplingModes(StrEnum): + bicubic = "bicubic" + nearest = "nearest" + bilinear = "bilinear" + + +class SPADENetResBlock(nn.Module): + """ + Creates a Residual Block with SPADE normalisation. + + Args: + spatial_dims: number of spatial dimensions + in_channels: number of input channels + out_channels: number of output channels + label_nc: number of semantic channels that will be taken into account in SPADE normalisation blocks + spade_intermediate_channels: number of intermediate channels in the middle conv. layers in SPADE normalisation blocks + norm: base normalisation type used on top of SPADE + kernel_size: convolutional kernel size + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + label_nc: int, + spade_intermediate_channels: int = 128, + norm: str | tuple = "INSTANCE", + act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + kernel_size: int = 3, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.int_channels = min(self.in_channels, self.out_channels) + self.learned_shortcut = self.in_channels != self.out_channels + self.conv_0 = Convolution( + spatial_dims=spatial_dims, in_channels=self.in_channels, out_channels=self.int_channels, act=None, norm=None + ) + self.conv_1 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.int_channels, + out_channels=self.out_channels, + act=None, + norm=None, + ) + self.activation = get_act_layer(act) + self.norm_0 = SPADE( + label_nc=label_nc, + norm_nc=self.in_channels, + kernel_size=kernel_size, + spatial_dims=spatial_dims, + hidden_channels=spade_intermediate_channels, + norm=norm, + ) + self.norm_1 = SPADE( + label_nc=label_nc, + norm_nc=self.int_channels, + kernel_size=kernel_size, + spatial_dims=spatial_dims, + hidden_channels=spade_intermediate_channels, + norm=norm, + ) + + if self.learned_shortcut: + self.conv_s = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + act=None, + norm=None, + kernel_size=1, + ) + self.norm_s = SPADE( + label_nc=label_nc, + norm_nc=self.in_channels, + kernel_size=kernel_size, + spatial_dims=spatial_dims, + hidden_channels=spade_intermediate_channels, + norm=norm, + ) + + def forward(self, x, seg): + x_s = self.shortcut(x, seg) + dx = self.conv_0(self.activation(self.norm_0(x, seg))) + dx = self.conv_1(self.activation(self.norm_1(dx, seg))) + out = x_s + dx + return out + + def shortcut(self, x, seg): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg)) + else: + x_s = x + return x_s + + +class SPADEEncoder(nn.Module): + """ + Encoding branch of a VAE compatible with a SPADE-like generator + + Args: + spatial_dims: number of spatial dimensions + in_channels: number of input channels + z_dim: latent space dimension of the VAE containing the image sytle information + channels: number of output after each downsampling block + input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers + of the autoencoder (HxWx[D]) + kernel_size: convolutional kernel size + norm: normalisation layer type + act: activation type + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + z_dim: int, + channels: Sequence[int], + input_shape: Sequence[int], + kernel_size: int = 3, + norm: str | tuple = "INSTANCE", + act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + ): + super().__init__() + self.in_channels = in_channels + self.z_dim = z_dim + self.channels = channels + if len(input_shape) != spatial_dims: + raise ValueError("Length of parameter input shape must match spatial_dims; got %s" % (input_shape)) + for s_ind, s_ in enumerate(input_shape): + if s_ / (2 ** len(channels)) != s_ // (2 ** len(channels)): + raise ValueError( + "Each dimension of your input must be divisible by 2 ** (autoencoder depth)." + "The shape in position %d, %d is not divisible by %d. " % (s_ind, s_, len(channels)) + ) + self.input_shape = input_shape + self.latent_spatial_shape = [s_ // (2 ** len(self.channels)) for s_ in self.input_shape] + blocks = [] + ch_init = self.in_channels + for _, ch_value in enumerate(channels): + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=ch_init, + out_channels=ch_value, + strides=2, + kernel_size=kernel_size, + norm=norm, + act=act, + ) + ) + ch_init = ch_value + + self.blocks = nn.ModuleList(blocks) + self.fc_mu = nn.Linear( + in_features=np.prod(self.latent_spatial_shape) * self.channels[-1], out_features=self.z_dim + ) + self.fc_var = nn.Linear( + in_features=np.prod(self.latent_spatial_shape) * self.channels[-1], out_features=self.z_dim + ) + + def forward(self, x): + for block in self.blocks: + x = block(x) + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + logvar = self.fc_var(x) + return mu, logvar + + def encode(self, x): + for block in self.blocks: + x = block(x) + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + logvar = self.fc_var(x) + return self.reparameterize(mu, logvar) + + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return eps.mul(std) + mu + + +class SPADEDecoder(nn.Module): + """ + Decoder branch of a SPADE-like generator. It can be used independently, without an encoding branch, + behaving like a GAN, or coupled to a SPADE encoder. + + Args: + label_nc: number of semantic labels + spatial_dims: number of spatial dimensions + out_channels: number of output channels + label_nc: number of semantic channels used for the SPADE normalisation blocks + input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers + channels: number of output after each downsampling block + z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used) + is_vae: whether the decoder is going to be coupled to an autoencoder or not (true: yes, false: no) + spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks + norm: base normalisation type + act: activation layer type + last_act: activation layer type for the last layer of the network (can differ from previous) + kernel_size: convolutional kernel size + upsampling_mode: upsampling mode (nearest, bilinear etc.) + """ + + def __init__( + self, + spatial_dims: int, + out_channels: int, + label_nc: int, + input_shape: Sequence[int], + channels: list[int], + z_dim: int | None = None, + is_vae: bool = True, + spade_intermediate_channels: int = 128, + norm: str | tuple = "INSTANCE", + act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + last_act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}), + kernel_size: int = 3, + upsampling_mode: str = UpsamplingModes.nearest.value, + ): + super().__init__() + self.is_vae = is_vae + self.out_channels = out_channels + self.label_nc = label_nc + self.num_channels = channels + if len(input_shape) != spatial_dims: + raise ValueError("Length of parameter input shape must match spatial_dims; got %s" % (input_shape)) + for s_ind, s_ in enumerate(input_shape): + if s_ / (2 ** len(channels)) != s_ // (2 ** len(channels)): + raise ValueError( + "Each dimension of your input must be divisible by 2 ** (autoencoder depth)." + "The shape in position %d, %d is not divisible by %d. " % (s_ind, s_, len(channels)) + ) + self.latent_spatial_shape = [s_ // (2 ** len(self.num_channels)) for s_ in input_shape] + + if not self.is_vae: + self.conv_init = Convolution( + spatial_dims=spatial_dims, in_channels=label_nc, out_channels=channels[0], kernel_size=kernel_size + ) + elif self.is_vae and z_dim is None: + raise ValueError( + "If the network is used in VAE-GAN mode, parameter z_dim " + "(number of latent channels in the VAE) must be populated." + ) + else: + self.fc = nn.Linear(z_dim, np.prod(self.latent_spatial_shape) * channels[0]) + + self.z_dim = z_dim + blocks = [] + channels.append(self.out_channels) + self.upsampling = torch.nn.Upsample(scale_factor=2, mode=upsampling_mode) + for ch_ind, ch_value in enumerate(channels[:-1]): + blocks.append( + SPADENetResBlock( + spatial_dims=spatial_dims, + in_channels=ch_value, + out_channels=channels[ch_ind + 1], + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + norm=norm, + kernel_size=kernel_size, + act=act, + ) + ) + + self.blocks = torch.nn.ModuleList(blocks) + self.last_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=channels[-1], + out_channels=out_channels, + padding=(kernel_size - 1) // 2, + kernel_size=kernel_size, + norm=None, + act=last_act, + ) + + def forward(self, seg, z: torch.Tensor | None = None): + """ + Args: + seg: input BxCxHxW[xD] semantic map on which the output is conditioned on + z: latent vector output by the encoder if self.is_vae is True. When is_vae is + False, z is a random noise vector. + + Returns: + + """ + if not self.is_vae: + x = F.interpolate(seg, size=tuple(self.latent_spatial_shape)) + x = self.conv_init(x) + else: + if ( + z is None and self.z_dim is not None + ): # Even though this network is a VAE (self.is_vae), you should be able to sample from noise as well. + z = torch.randn(seg.size(0), self.z_dim, dtype=torch.float32, device=seg.get_device()) + x = self.fc(z) + x = x.view(*[-1, self.num_channels[0]] + self.latent_spatial_shape) + + for res_block in self.blocks: + x = res_block(x, seg) + x = self.upsampling(x) + + x = self.last_conv(x) + return x + + +class SPADENet(nn.Module): + """ + SPADE Network, implemented based on the code by Park, T et al. in + "Semantic Image Synthesis with Spatially-Adaptive Normalization" + (https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions + in_channels: number of input channels + out_channels: number of output channels + label_nc: number of semantic channels used for the SPADE normalisation blocks + input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers + channels: number of output after each downsampling block + z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used) + is_vae: whether the decoder is going to be coupled to an autoencoder (true) or not (false) + spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks + norm: base normalisation type + act: activation layer type + last_act: activation layer type for the last layer of the network (can differ from previous) + kernel_size: convolutional kernel size + upsampling_mode: upsampling mode (nearest, bilinear etc.) + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + label_nc: int, + input_shape: Sequence[int], + channels: list[int], + z_dim: int | None = None, + is_vae: bool = True, + spade_intermediate_channels: int = 128, + norm: str | tuple = "INSTANCE", + act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + last_act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}), + kernel_size: int = 3, + upsampling_mode: str = UpsamplingModes.nearest.value, + ): + super().__init__() + self.is_vae = is_vae + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.label_nc = label_nc + self.input_shape = input_shape + + if self.is_vae: + if z_dim is None: + ValueError("The latent space dimension mapped by parameter z_dim cannot be None is is_vae is True.") + else: + self.encoder = SPADEEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + z_dim=z_dim, + channels=channels, + input_shape=input_shape, + kernel_size=kernel_size, + norm=norm, + act=act, + ) + + decoder_channels = channels + decoder_channels.reverse() + + self.decoder = SPADEDecoder( + spatial_dims=spatial_dims, + out_channels=out_channels, + label_nc=label_nc, + input_shape=input_shape, + channels=decoder_channels, + z_dim=z_dim, + is_vae=is_vae, + spade_intermediate_channels=spade_intermediate_channels, + norm=norm, + act=act, + last_act=last_act, + kernel_size=kernel_size, + upsampling_mode=upsampling_mode, + ) + + def forward(self, seg: torch.Tensor, x: torch.Tensor | None = None): + z = None + if self.is_vae: + z_mu, z_logvar = self.encoder(x) + z = self.encoder.reparameterize(z_mu, z_logvar) + return self.decoder(seg, z), z_mu, z_logvar + else: + return (self.decoder(seg, z),) + + def encode(self, x: torch.Tensor): + if self.is_vae: + return self.encoder.encode(x) + else: + return None + + def decode(self, seg: torch.Tensor, z: torch.Tensor | None = None): + return self.decoder(seg, z) diff --git a/tests/test_spade_vaegan.py b/tests/test_spade_vaegan.py new file mode 100644 index 0000000000..3fdb9b74cb --- /dev/null +++ b/tests/test_spade_vaegan.py @@ -0,0 +1,140 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import SPADENet + +CASE_2D = [ + [[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]], + [[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], None, False]], +] +CASE_3D = [ + [[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], 16, True]], + [[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], None, False]], +] + + +def create_semantic_data(shape: list, semantic_regions: int): + """ + To create semantic and image mock inputs for the network. + Args: + shape: input shape + semantic_regions: number of semantic region + Returns: + """ + out_label = torch.zeros(shape) + out_image = torch.zeros(shape) + torch.randn(shape) * 0.01 + for i in range(1, semantic_regions): + shape_square = [i // np.random.choice(list(range(2, i // 2))) for i in shape] + start_point = [np.random.choice(list(range(shape[ind] - shape_square[ind]))) for ind, i in enumerate(shape)] + if len(shape) == 2: + out_label[ + start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1]) + ] = i + base_intensity = torch.ones(shape_square) * np.random.randn() + out_image[ + start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1]) + ] = (base_intensity + torch.randn(shape_square) * 0.1) + elif len(shape) == 3: + out_label[ + start_point[0] : (start_point[0] + shape_square[0]), + start_point[1] : (start_point[1] + shape_square[1]), + start_point[2] : (start_point[2] + shape_square[2]), + ] = i + base_intensity = torch.ones(shape_square) * np.random.randn() + out_image[ + start_point[0] : (start_point[0] + shape_square[0]), + start_point[1] : (start_point[1] + shape_square[1]), + start_point[2] : (start_point[2] + shape_square[2]), + ] = (base_intensity + torch.randn(shape_square) * 0.1) + else: + ValueError("Supports only 2D and 3D tensors") + + # One hot encode label + out_label_ = torch.zeros([semantic_regions] + list(out_label.shape)) + for ch in range(semantic_regions): + out_label_[ch, ...] = out_label == ch + + return out_label_.unsqueeze(0), out_image.unsqueeze(0).unsqueeze(0) + + +class TestSpadeNet(unittest.TestCase): + @parameterized.expand(CASE_2D) + def test_forward_2d(self, input_param): + """ + Check that forward method is called correctly and output shape matches. + """ + net = SPADENet(*input_param) + in_label, in_image = create_semantic_data(input_param[4], input_param[3]) + with eval_mode(net): + if not net.is_vae: + out = net(in_label, in_image) + out = out[0] + else: + out, z_mu, z_logvar = net(in_label, in_image) + self.assertTrue(torch.all(torch.isfinite(z_mu))) + self.assertTrue(torch.all(torch.isfinite(z_logvar))) + + self.assertTrue(torch.all(torch.isfinite(out))) + self.assertEqual(list(out.shape), [1, 1, 64, 64]) + + @parameterized.expand(CASE_2D) + def test_encoder_decoder(self, input_param): + """ + Check that forward method is called correctly and output shape matches. + """ + net = SPADENet(*input_param) + in_label, in_image = create_semantic_data(input_param[4], input_param[3]) + with eval_mode(net): + out_z = net.encode(in_image) + if net.is_vae: + self.assertEqual(list(out_z.shape), [1, 16]) + else: + self.assertEqual(out_z, None) + out_i = net.decode(in_label, out_z) + self.assertEqual(list(out_i.shape), [1, 1, 64, 64]) + + @parameterized.expand(CASE_3D) + def test_forward_3d(self, input_param): + """ + Check that forward method is called correctly and output shape matches. + """ + net = SPADENet(*input_param) + in_label, in_image = create_semantic_data(input_param[4], input_param[3]) + with eval_mode(net): + if net.is_vae: + out, z_mu, z_logvar = net(in_label, in_image) + self.assertTrue(torch.all(torch.isfinite(z_mu))) + self.assertTrue(torch.all(torch.isfinite(z_logvar))) + else: + out = net(in_label, in_image) + out = out[0] + self.assertTrue(torch.all(torch.isfinite(out))) + self.assertEqual(list(out.shape), [1, 1, 64, 64, 64]) + + def test_shape_wrong(self): + """ + We input an input shape that isn't divisible by 2**(n downstream steps) + """ + with self.assertRaises(ValueError): + _ = SPADENet(1, 1, 8, [16, 16], [16, 32, 64, 128], 16, True) + + +if __name__ == "__main__": + unittest.main() From 98550c028a313e7b86c0497e08fb09df45dfb091 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez <61539159+virginiafdez@users.noreply.github.com> Date: Fri, 21 Jun 2024 16:09:50 +0100 Subject: [PATCH 20/32] Scheduler Clip Fix (#7855) Fixes # . ### Description Fixes a bug in the inferer and adds clipping parameters to the DDIM/DDPM schedulers. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: virginiafdez Co-authored-by: virginiafdez --- monai/inferers/inferer.py | 2 +- monai/networks/schedulers/ddim.py | 13 +++++++++++-- monai/networks/schedulers/ddpm.py | 9 ++++++++- requirements.txt | 2 +- 4 files changed, 21 insertions(+), 5 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 72bcb8fd5a..769b6cc0e7 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1607,7 +1607,7 @@ def __init__( self.autoencoder_latent_shape = autoencoder_latent_shape if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None: self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) - self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape) + self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape) def __call__( # type: ignore[override] self, diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py index 78e3cc2a0c..19e24d94b8 100644 --- a/monai/networks/schedulers/ddim.py +++ b/monai/networks/schedulers/ddim.py @@ -57,6 +57,8 @@ class DDIMScheduler(Scheduler): `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. prediction_type: member of DDPMPredictionType + clip_sample_min: minimum clipping value when clip_sample equals True + clip_sample_max: maximum clipping value when clip_sample equals True schedule_args: arguments to pass to the schedule function """ @@ -69,6 +71,8 @@ def __init__( set_alpha_to_one: bool = True, steps_offset: int = 0, prediction_type: str = DDIMPredictionType.EPSILON, + clip_sample_min: float = -1.0, + clip_sample_max: float = 1.0, **schedule_args, ) -> None: super().__init__(num_train_timesteps, schedule, **schedule_args) @@ -90,6 +94,7 @@ def __init__( self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64)) self.clip_sample = clip_sample + self.clip_sample_values = [clip_sample_min, clip_sample_max] self.steps_offset = steps_offset # default the number of inference timesteps to the number of train steps @@ -193,7 +198,9 @@ def step( # 4. Clip "predicted x_0" if self.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + pred_original_sample = torch.clamp( + pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1] + ) # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) @@ -266,7 +273,9 @@ def reversed_step( # 4. Clip "predicted x_0" if self.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + pred_original_sample = torch.clamp( + pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1] + ) # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py index a5173a1b65..93ad833031 100644 --- a/monai/networks/schedulers/ddpm.py +++ b/monai/networks/schedulers/ddpm.py @@ -77,6 +77,8 @@ class DDPMScheduler(Scheduler): variance_type: member of DDPMVarianceType clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. prediction_type: member of DDPMPredictionType + clip_sample_min: minimum clipping value when clip_sample equals True + clip_sample_max: maximum clipping value when clip_sample equals True schedule_args: arguments to pass to the schedule function """ @@ -87,6 +89,8 @@ def __init__( variance_type: str = DDPMVarianceType.FIXED_SMALL, clip_sample: bool = True, prediction_type: str = DDPMPredictionType.EPSILON, + clip_sample_min: float = -1.0, + clip_sample_max: float = 1.0, **schedule_args, ) -> None: super().__init__(num_train_timesteps, schedule, **schedule_args) @@ -98,6 +102,7 @@ def __init__( raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`") self.clip_sample = clip_sample + self.clip_sample_values = [clip_sample_min, clip_sample_max] self.variance_type = variance_type self.prediction_type = prediction_type @@ -219,7 +224,9 @@ def step( # 3. Clip "predicted x_0" if self.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + pred_original_sample = torch.clamp( + pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1] + ) # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf diff --git a/requirements.txt b/requirements.txt index 1569646794..1d6ae13eec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ torch>=1.9 -numpy>=1.20 +numpy>=1.20,<2.0 From 15ff66397d3eedf696855b92c5b66ba0ab624471 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Tue, 2 Jul 2024 05:57:06 +0100 Subject: [PATCH 21/32] Merging Dev Into gen-ai-dev and Undeclared Variable Fixes (#7887) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description This merges `dev` into `gen-ai-dev` and modifies/fixes some places in code to remove undeclared variable errors raised by pylint. This should all be non-breaking and not changing anything functional in this branch. When completed, the branch should be ready for merging into `dev`. This also resolves various conflicts which need to be checked for correctness. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu Signed-off-by: kaibo Signed-off-by: heyufan1995 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: binliu Signed-off-by: dependabot[bot] Signed-off-by: axel.vlaminck Signed-off-by: monai-bot Signed-off-by: Ibrahim Hadzic Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> Signed-off-by: Timothy Baker Signed-off-by: Mathijs de Boer Signed-off-by: Fabian Klopfer Signed-off-by: Lucas Robinet Signed-off-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Signed-off-by: chaoliu Signed-off-by: cxlcl Signed-off-by: chaoliu Signed-off-by: Suraj Pai Signed-off-by: Juan Pablo de la Cruz Gutiérrez Signed-off-by: elitap Signed-off-by: Felix Schnabel Signed-off-by: YanxuanLiu Signed-off-by: ytl0623 Signed-off-by: Dženan Zukić Signed-off-by: Ishan Dutta Signed-off-by: John Zielke Signed-off-by: Mingxin Zheng Signed-off-by: Vladimir Chernyi <57420464+scalyvladimir@users.noreply.github.com> Signed-off-by: Yiheng Wang Signed-off-by: Szabolcs Botond Lorincz Molnar Signed-off-by: Lucas Robinet Signed-off-by: Mingxin Signed-off-by: Han Wang Signed-off-by: Konstantin Sukharev Signed-off-by: Ben Murray Signed-off-by: Matthew Vine <32849887+MattTheCuber@users.noreply.github.com> Signed-off-by: Peter Kaplinsky Signed-off-by: Simon Jensen <61684806+simojens@users.noreply.github.com> Signed-off-by: NabJa Signed-off-by: Eric Kerfoot Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Kaibo Tang Co-authored-by: Yufan He <59374597+heyufan1995@users.noreply.github.com> Co-authored-by: binliunls <107988372+binliunls@users.noreply.github.com> Co-authored-by: Ben Murray Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: axel.vlaminck Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Co-authored-by: monai-bot <64792179+monai-bot@users.noreply.github.com> Co-authored-by: Ibrahim Hadzic Co-authored-by: Dr. Behrooz Hashemian <3968947+drbeh@users.noreply.github.com> Co-authored-by: Timothy J. Baker <62781117+tim-the-baker@users.noreply.github.com> Co-authored-by: Mathijs de Boer <8137653+MathijsdeBoer@users.noreply.github.com> Co-authored-by: Mathijs de Boer Co-authored-by: Fabian Klopfer Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Co-authored-by: Lucas Robinet Co-authored-by: cxlcl Co-authored-by: Suraj Pai Co-authored-by: Juampa <1523654+juampatronics@users.noreply.github.com> Co-authored-by: elitap Co-authored-by: Felix Schnabel Co-authored-by: YanxuanLiu <104543031+YanxuanLiu@users.noreply.github.com> Co-authored-by: ytl0623 Co-authored-by: Dženan Zukić Co-authored-by: Ishan Dutta Co-authored-by: johnzielke Co-authored-by: Vladimir Chernyi <57420464+scalyvladimir@users.noreply.github.com> Co-authored-by: Lőrincz-Molnár Szabolcs-Botond Co-authored-by: Nic Ma Co-authored-by: Lucas Robinet Co-authored-by: Han Wang Co-authored-by: Konstantin Sukharev <50718389+k-sukharev@users.noreply.github.com> Co-authored-by: Matthew Vine <32849887+MattTheCuber@users.noreply.github.com> Co-authored-by: Pkaps25 <43655728+Pkaps25@users.noreply.github.com> Co-authored-by: Peter Kaplinsky Co-authored-by: Simon Jensen <61684806+simojens@users.noreply.github.com> Co-authored-by: NabJa <32510324+NabJa@users.noreply.github.com> Co-authored-by: Yu <146002968+Yu0610@users.noreply.github.com> --- .github/workflows/blossom-ci.yml | 16 ++- .github/workflows/pythonapp.yml | 1 + .github/workflows/release.yml | 3 +- CHANGELOG.md | 95 +++++++++++++- CITATION.cff | 4 +- README.md | 1 - docs/requirements.txt | 4 +- docs/source/conf.py | 21 ++- docs/source/networks.rst | 1 - monai/apps/auto3dseg/auto_runner.py | 4 +- monai/apps/detection/utils/anchor_utils.py | 4 +- monai/apps/pathology/transforms/post/array.py | 1 + monai/bundle/utils.py | 1 + monai/bundle/workflows.py | 1 - monai/data/dataset.py | 52 +++----- monai/data/dataset_summary.py | 1 + monai/data/image_reader.py | 2 +- monai/data/torchscript_utils.py | 2 +- monai/data/utils.py | 15 ++- monai/networks/blocks/spatialattention.py | 2 +- monai/networks/nets/daf3d.py | 28 ++-- monai/networks/nets/quicknat.py | 2 + monai/networks/nets/resnet.py | 70 ++++++---- monai/networks/nets/swin_unetr.py | 7 +- monai/networks/schedulers/ddim.py | 8 ++ monai/transforms/__init__.py | 1 + monai/transforms/regularization/array.py | 59 ++++++--- monai/transforms/regularization/dictionary.py | 80 ++++++++---- monai/transforms/utils.py | 65 +++++++++ .../transforms/utils_create_transform_ims.py | 6 +- requirements-dev.txt | 4 +- setup.cfg | 8 +- tests/hvd_evenly_divisible_all_gather.py | 8 +- tests/test_arraydataset.py | 2 +- tests/test_clip_intensity_percentiles.py | 89 +++++++------ tests/test_clip_intensity_percentilesd.py | 69 +++++----- tests/test_controlnet_inferers.py | 21 +++ tests/test_dataset.py | 68 +++++++++- tests/test_ensure_channel_first.py | 7 +- tests/test_ensure_channel_firstd.py | 7 +- .../test_evenly_divisible_all_gather_dist.py | 8 +- tests/test_handler_metrics_saver_dist.py | 4 +- tests/test_hilbert_transform.py | 123 ++++-------------- tests/test_integration_unet_2d.py | 1 + tests/test_latent_diffusion_inferer.py | 16 +++ .../test_map_and_generate_sampling_centers.py | 87 +++++++++++++ tests/test_pad_collation.py | 6 +- tests/test_profiling.py | 4 +- tests/test_reg_loss_integration.py | 3 + tests/test_regularization.py | 80 ++++++++---- tests/test_resnet.py | 45 +++++-- tests/test_synthetic.py | 2 +- tests/test_vis_cam.py | 3 + tests/test_vis_gradcam.py | 2 + tests/test_warp.py | 1 + tests/utils.py | 1 + 56 files changed, 830 insertions(+), 396 deletions(-) create mode 100644 tests/test_map_and_generate_sampling_centers.py diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml index 1d6ee8a46c..bf507bab3b 100644 --- a/.github/workflows/blossom-ci.yml +++ b/.github/workflows/blossom-ci.yml @@ -29,13 +29,15 @@ jobs: args: ${{ env.args }} # This job only runs for pull request comments - if: contains('\ - Nic-Ma,\ - wyli,\ - pxLi,\ - YanxuanLiu,\ - KumoLiu,\ - ', format('{0},', github.actor)) && github.event.comment.body == '/build' + if: | + github.event.comment.body == '/build' && + ( + github.actor == 'Nic-Ma' || + github.actor == 'wyli' || + github.actor == 'pxLi' || + github.actor == 'YanxuanLiu' || + github.actor == 'KumoLiu' + ) steps: - name: Check if comment is issued by authorized person run: blossom-ci diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index b8b73907d4..d1e77bb567 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -100,6 +100,7 @@ jobs: python -m pip install --pre -U itk - name: Install the dependencies run: | + python -m pip install --user --upgrade pip wheel python -m pip install torch==1.13.1 torchvision==0.14.1 cat "requirements-dev.txt" python -m pip install -r requirements-dev.txt diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c134724665..60b610565e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -119,7 +119,8 @@ jobs: rm -rf {*,.[^.]*} release_tag_docker: - if: github.repository == 'Project-MONAI/MONAI' + # if: github.repository == 'Project-MONAI/MONAI' + if: ${{ false }} needs: versioning runs-on: ubuntu-latest steps: diff --git a/CHANGELOG.md b/CHANGELOG.md index 61be8f07c1..38336505ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,98 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [Unreleased] +## [1.3.1] - 2024-05-17 +### Added +* Support for `by_measure` argument in `RemoveSmallObjects` (#7137) +* Support for `pretrained` flag in `ResNet` (#7095) +* Support for uploading and downloading bundles to and from the Hugging Face Hub (#6454) +* Added weight parameter in DiceLoss to apply weight to voxels of each class (#7158) +* Support for returning dice for each class in `DiceMetric` (#7163) +* Introduced `ComponentStore` for storage purposes (#7159) +* Added utilities used in MONAI Generative (#7134) +* Enabled Python 3.11 support for `convert_to_torchscript` and `convert_to_onnx` (#7182) +* Support for MLflow in `AutoRunner` (#7176) +* `fname_regex` option in PydicomReader (#7181) +* Allowed setting AutoRunner parameters from config (#7175) +* `VoxelMorphUNet` and `VoxelMorph` (#7178) +* Enabled `cache` option in `GridPatchDataset` (#7180) +* Introduced `class_labels` option in `write_metrics_reports` for improved readability (#7249) +* `DiffusionLoss` for image registration task (#7272) +* Supported specifying `filename` in `Saveimage` (#7318) +* Compile support in `SupervisedTrainer` and `SupervisedEvaluator` (#7375) +* `mlflow_experiment_name` support in `Auto3DSeg` (#7442) +* Arm support (#7500) +* `BarlowTwinsLoss` for representation learning (#7530) +* `SURELoss` and `ConjugateGradient` for diffusion models (#7308) +* Support for `CutMix`, `CutOut`, and `MixUp` augmentation techniques (#7198) +* `meta_file` and `logging_file` options to `BundleWorkflow` (#7549) +* `properties_path` option to `BundleWorkflow` for customized properties (#7542) +* Support for both soft and hard clipping in `ClipIntensityPercentiles` (#7535) +* Support for not saving artifacts in `MLFlowHandler` (#7604) +* Support for multi-channel images in `PerceptualLoss` (#7568) +* Added ResNet backbone for `FlexibleUNet` (#7571) +* Introduced `dim_head` option in `SABlock` to set dimensions for each head (#7664) +* Direct links to github source code to docs (#7738, #7779) +#### misc. +* Refactored `list_data_collate` and `collate_meta_tensor` to utilize the latest PyTorch API (#7165) +* Added __str__ method in `Metric` base class (#7487) +* Made enhancements for testing files (#7662, #7670, #7663, #7671, #7672) +* Improved documentation for bundles (#7116) +### Fixed +#### transforms +* Addressed issue where lazy mode was ignored in `SpatialPadd` (#7316) +* Tracked applied operations in `ImageFilter` (#7395) +* Warnings are now given only if missing class is not set to 0 in `generate_label_classes_crop_centers` (#7602) +* Input is now always converted to C-order in `distance_transform_edt` to ensure consistent behavior (#7675) +#### data +* Modified .npz file behavior to use keys in `NumpyReader` (#7148) +* Handled corrupted cached files in `PersistentDataset` (#7244) +* Corrected affine update in `NrrdReader` (#7415) +#### metrics and losses +* Addressed precision issue in `get_confusion_matrix` (#7187) +* Harmonized and clarified documentation and tests for dice losses variants (#7587) +#### networks +* Removed hard-coded `spatial_dims` in `SwinTransformer` (#7302) +* Fixed learnable `position_embeddings` in `PatchEmbeddingBlock` (#7564, #7605) +* Removed `memory_pool_limit` in TRT config (#7647) +* Propagated `kernel_size` to `ConvBlocks` within `AttentionUnet` (#7734) +* Addressed hard-coded activation layer in `ResNet` (#7749) +#### bundle +* Resolved bundle download issue (#7280) +* Updated `bundle_root` directory for `NNIGen` (#7586) +* Checked for `num_fold` and failed early if incorrect (#7634) +* Enhanced logging logic in `ConfigWorkflow` (#7745) +#### misc. +* Enabled chaining in `Auto3DSeg` CLI (#7168) +* Addressed useless error message in `nnUNetV2Runner` (#7217) +* Resolved typing and deprecation issues in Mypy (#7231) +* Quoted `$PY_EXE` variable to handle Python path that contains spaces in Bash (#7268) +* Improved documentation, code examples, and warning messages in various modules (#7234, #7213, #7271, #7326, #7569, #7584) +* Fixed typos in various modules (#7321, #7322, #7458, #7595, #7612) +* Enhanced docstrings in various modules (#7245, #7381, #7746) +* Handled error when data is on CPU in `DataAnalyzer` (#7310) +* Updated version requirements for third-party packages (#7343, #7344, #7384, #7448, #7659, #7704, #7744, #7742, #7780) +* Addressed incorrect slice compute in `ImageStats` (#7374) +* Avoided editing a loop's mutable iterable to address B308 (#7397) +* Fixed issue with `CUDA_VISIBLE_DEVICES` setting being ignored (#7408, #7581) +* Avoided changing Python version in CICD (#7424) +* Renamed partial to callable in instantiate mode (#7413) +* Imported AttributeError for Python 3.12 compatibility (#7482) +* Updated `nnUNetV2Runner` to support nnunetv2 2.2 (#7483) +* Used uint8 instead of int8 in `LabelStats` (#7489) +* Utilized subprocess for nnUNet training (#7576) +* Addressed deprecated warning in ruff (#7625) +* Fixed downloading failure on FIPS machine (#7698) +* Updated `torch_tensorrt` compile parameters to avoid warning (#7714) +* Restrict `Auto3DSeg` fold input based on datalist (#7778) +### Changed +* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:24.03-py3` from `nvcr.io/nvidia/pytorch:23.08-py3` +### Removed +* Removed unrecommended star-arg unpacking after a keyword argument, addressed B026 (#7262) +* Skipped old PyTorch version test for `SwinUNETR` (#7266) +* Dropped docker build workflow and migrated to Nvidia Blossom system (#7450) +* Dropped Python 3.8 test on quick-py3 workflow (#7719) + ## [1.3.0] - 2023-10-12 ### Added * Intensity transforms `ScaleIntensityFixedMean` and `RandScaleIntensityFixedMean` (#6542) @@ -943,7 +1035,8 @@ the postprocessing steps should be used before calling the metrics methods [highlights]: https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md -[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/1.3.0...HEAD +[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/1.3.1...HEAD +[1.3.1]: https://github.com/Project-MONAI/MONAI/compare/1.3.0...1.3.1 [1.3.0]: https://github.com/Project-MONAI/MONAI/compare/1.2.0...1.3.0 [1.2.0]: https://github.com/Project-MONAI/MONAI/compare/1.1.0...1.2.0 [1.1.0]: https://github.com/Project-MONAI/MONAI/compare/1.0.1...1.1.0 diff --git a/CITATION.cff b/CITATION.cff index cac47faae4..4754c5b2e3 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -6,8 +6,8 @@ title: "MONAI: Medical Open Network for AI" abstract: "AI Toolkit for Healthcare Imaging" authors: - name: "MONAI Consortium" -date-released: 2023-10-12 -version: "1.3.0" +date-released: 2024-05-21 +version: "1.3.1" identifiers: - description: "This DOI represents all versions of MONAI, and will always resolve to the latest one." type: doi diff --git a/README.md b/README.md index 7565fea1b7..5345cdb926 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,6 @@ [![premerge](https://github.com/Project-MONAI/MONAI/actions/workflows/pythonapp.yml/badge.svg?branch=dev)](https://github.com/Project-MONAI/MONAI/actions/workflows/pythonapp.yml) [![postmerge](https://img.shields.io/github/checks-status/project-monai/monai/dev?label=postmerge)](https://github.com/Project-MONAI/MONAI/actions?query=branch%3Adev) -[![docker](https://github.com/Project-MONAI/MONAI/actions/workflows/docker.yml/badge.svg?branch=dev)](https://github.com/Project-MONAI/MONAI/actions/workflows/docker.yml) [![Documentation Status](https://readthedocs.org/projects/monai/badge/?version=latest)](https://docs.monai.io/en/latest/) [![codecov](https://codecov.io/gh/Project-MONAI/MONAI/branch/dev/graph/badge.svg?token=6FTC7U1JJ4)](https://codecov.io/gh/Project-MONAI/MONAI) diff --git a/docs/requirements.txt b/docs/requirements.txt index 5acc437391..007281ac35 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -21,8 +21,8 @@ sphinxcontrib-serializinghtml sphinx-autodoc-typehints==1.11.1 pandas einops -transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157 -mlflow>=1.28.0, <=2.11.3 +transformers>=4.36.0, <4.41.0; python_version <= '3.10' +mlflow>=2.12.2 clearml>=1.10.0rc0 tensorboardX imagecodecs; platform_system == "Linux" or platform_system == "Darwin" diff --git a/docs/source/conf.py b/docs/source/conf.py index 827626d12e..a91f38081f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -97,7 +97,7 @@ def generate_apidocs(*args): "sphinx.ext.mathjax", "sphinx.ext.napoleon", "sphinx.ext.autodoc", - "sphinx.ext.viewcode", + "sphinx.ext.linkcode", "sphinx.ext.autosectionlabel", "sphinx.ext.autosummary", "sphinx_autodoc_typehints", @@ -140,7 +140,7 @@ def generate_apidocs(*args): "github_repo": "MONAI", "github_version": "dev", "doc_path": "docs/source", - "conf_py_path": "/docs/", + "conf_py_path": "/docs/source", "VERSION": version, } html_scaled_image_link = False @@ -167,11 +167,24 @@ def setup(app): # -- Linkcode configuration -------------------------------------------------- +DEFAULT_REF = "dev" +read_the_docs_ref = os.environ.get("READTHEDOCS_GIT_IDENTIFIER", None) +if read_the_docs_ref: + # When building on ReadTheDocs, link to the specific commit + # https://docs.readthedocs.io/en/stable/reference/environment-variables.html#envvar-READTHEDOCS_GIT_IDENTIFIER + git_ref = read_the_docs_ref +elif os.environ.get("GITHUB_REF_TYPE", "branch") == "tag": + # When building a tag, link to the tag itself + git_ref = os.environ.get("GITHUB_REF", DEFAULT_REF) +else: + git_ref = os.environ.get("GITHUB_SHA", DEFAULT_REF) + DEFAULT_REPOSITORY = "Project-MONAI/MONAI" repository = os.environ.get("GITHUB_REPOSITORY", DEFAULT_REPOSITORY) -base_code_url = f"https://github.com/{repository}/blob/{version}" +base_code_url = f"https://github.com/{repository}/blob/{git_ref}" MODULE_ROOT_FOLDER = "monai" +repo_root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) # Adjusted from https://github.com/python-websockets/websockets/blob/main/docs/conf.py @@ -201,7 +214,7 @@ def linkcode_resolve(domain, info): except TypeError: # e.g. object is a typing.Union return None - file = os.path.relpath(file, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) + file = os.path.relpath(file, repo_root_path) if not file.startswith(MODULE_ROOT_FOLDER): # e.g. object is a typing.NewType return None diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 8321fed1a4..c51f5c88b1 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -426,7 +426,6 @@ Layers .. autoclass:: monai.networks.layers.vector_quantizer.VectorQuantizer :members: -======= `ConjugateGradient` ~~~~~~~~~~~~~~~~~~~ .. autoclass:: ConjugateGradient diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py index 05c961f999..5b6b501555 100644 --- a/monai/apps/auto3dseg/auto_runner.py +++ b/monai/apps/auto3dseg/auto_runner.py @@ -499,8 +499,8 @@ def set_num_fold(self, num_fold: int = 5) -> AutoRunner: if num_fold <= 0: raise ValueError(f"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}") - if num_fold > self.max_fold + 1: - # Auto3DSeg allows no validation set, so the maximum fold number is max_fold + 1 + if num_fold > self.max_fold: + # Auto3DSeg must contain validation set, so the maximum fold number is max_fold. raise ValueError( f"num_fold is greater than the maximum fold number {self.max_fold} in {self.datalist_filename}." ) diff --git a/monai/apps/detection/utils/anchor_utils.py b/monai/apps/detection/utils/anchor_utils.py index 283169b653..cbde3ebae9 100644 --- a/monai/apps/detection/utils/anchor_utils.py +++ b/monai/apps/detection/utils/anchor_utils.py @@ -189,7 +189,7 @@ def generate_anchors( w_ratios = 1 / area_scale h_ratios = area_scale # if 3d, w:h:d = 1:aspect_ratios[:,0]:aspect_ratios[:,1] - elif self.spatial_dims == 3: + else: area_scale = torch.pow(aspect_ratios_t[:, 0] * aspect_ratios_t[:, 1], 1 / 3.0) w_ratios = 1 / area_scale h_ratios = aspect_ratios_t[:, 0] / area_scale @@ -199,7 +199,7 @@ def generate_anchors( hs = (h_ratios[:, None] * scales_t[None, :]).view(-1) if self.spatial_dims == 2: base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2.0 - elif self.spatial_dims == 3: + else: # elif self.spatial_dims == 3: ds = (d_ratios[:, None] * scales_t[None, :]).view(-1) base_anchors = torch.stack([-ws, -hs, -ds, ws, hs, ds], dim=1) / 2.0 diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 99e94f89c0..0f57fb41cb 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -379,6 +379,7 @@ def _generate_contour_coord(self, current: np.ndarray, previous: np.ndarray) -> """ p_delta = (current[0] - previous[0], current[1] - previous[1]) + row, col = -1, -1 if p_delta in ((0.0, 1.0), (0.5, 0.5), (1.0, 0.0)): row = int(current[0] + 0.5) diff --git a/monai/bundle/utils.py b/monai/bundle/utils.py index a0f39d236f..0f17422ba5 100644 --- a/monai/bundle/utils.py +++ b/monai/bundle/utils.py @@ -221,6 +221,7 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any raise ValueError(f"Cannot find config file '{full_cname}'") ardata = archive.read(full_cname) + cdata = {} if full_cname.lower().endswith("json"): cdata = json.loads(ardata, **load_kw_args) diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index b42852cb0f..11c9bf0562 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -308,7 +308,6 @@ def __init__( super().__init__(workflow_type=workflow_type, meta_file=meta_file, properties_path=properties_path) self.config_root_path = config_root_path logging_file = str(self.config_root_path / "logging.conf") if logging_file is None else logging_file - if logging_file is False: logger.warn(f"Logging file is set to {logging_file}, skipping logging.") else: diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 79e066303e..871b523289 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -36,15 +36,7 @@ from monai.data.meta_tensor import MetaTensor from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing -from monai.transforms import ( - Compose, - Randomizable, - RandomizableTrait, - Transform, - apply_transform, - convert_to_contiguous, - reset_ops_id, -) +from monai.transforms import Compose, Randomizable, RandomizableTrait, Transform, convert_to_contiguous, reset_ops_id from monai.utils import MAX_SEED, convert_to_tensor, get_seed, look_up_option, min_version, optional_import from monai.utils.misc import first @@ -77,15 +69,19 @@ class Dataset(_TorchDataset): }, }, }] """ - def __init__(self, data: Sequence, transform: Callable | None = None) -> None: + def __init__(self, data: Sequence, transform: Sequence[Callable] | Callable | None = None) -> None: """ Args: data: input data to load and transform to generate dataset for model. - transform: a callable data transform on input data. - + transform: a callable, sequence of callables or None. If transform is not + a `Compose` instance, it will be wrapped in a `Compose` instance. Sequences + of callables are applied in order and if `None` is passed, the data is returned as is. """ self.data = data - self.transform: Any = transform + try: + self.transform = Compose(transform) if not isinstance(transform, Compose) else transform + except Exception as e: + raise ValueError("`transform` must be a callable or a list of callables that is Composable") from e def __len__(self) -> int: return len(self.data) @@ -95,7 +91,7 @@ def _transform(self, index: int): Fetch single data item from `self.data`. """ data_i = self.data[index] - return apply_transform(self.transform, data_i) if self.transform is not None else data_i + return self.transform(data_i) def __getitem__(self, index: int | slice | Sequence[int]): """ @@ -264,8 +260,6 @@ def __init__( using the cached content and with re-created transform instances. """ - if not isinstance(transform, Compose): - transform = Compose(transform) super().__init__(data=data, transform=transform) self.cache_dir = Path(cache_dir) if cache_dir is not None else None self.hash_func = hash_func @@ -323,9 +317,6 @@ def _pre_transform(self, item_transformed): random transform object """ - if not isinstance(self.transform, Compose): - raise ValueError("transform must be an instance of monai.transforms.Compose.") - first_random = self.transform.get_index_of_first( lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) ) @@ -346,9 +337,6 @@ def _post_transform(self, item_transformed): the transformed element through the random transforms """ - if not isinstance(self.transform, Compose): - raise ValueError("transform must be an instance of monai.transforms.Compose.") - first_random = self.transform.get_index_of_first( lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) ) @@ -501,9 +489,6 @@ def _pre_transform(self, item_transformed): Returns: the transformed element up to the N transform object """ - if not isinstance(self.transform, Compose): - raise ValueError("transform must be an instance of monai.transforms.Compose.") - item_transformed = self.transform(item_transformed, end=self.cache_n_trans, threading=True) reset_ops_id(item_transformed) @@ -519,9 +504,6 @@ def _post_transform(self, item_transformed): Returns: the final transformed result """ - if not isinstance(self.transform, Compose): - raise ValueError("transform must be an instance of monai.transforms.Compose.") - return self.transform(item_transformed, start=self.cache_n_trans) @@ -809,8 +791,6 @@ def __init__( Not following these recommendations may lead to runtime errors or duplicated cache across processes. """ - if not isinstance(transform, Compose): - transform = Compose(transform) super().__init__(data=data, transform=transform) self.set_num = cache_num # tracking the user-provided `cache_num` option self.set_rate = cache_rate # tracking the user-provided `cache_rate` option @@ -1282,8 +1262,10 @@ def to_list(x): data = [] for dataset in self.data: data.extend(to_list(dataset[index])) + if self.transform is not None: - data = apply_transform(self.transform, data, map_items=False) # transform the list data + self.transform.map_items = False # Compose object map_items to false so transform is applied to list + data = self.transform(data) # use tuple instead of list as the default collate_fn callback of MONAI DataLoader flattens nested lists return tuple(data) @@ -1432,15 +1414,11 @@ def __len__(self): def _transform(self, index: int): data = {k: v[index] for k, v in self.arrays.items()} - - if not self.transform: - return data - - result = apply_transform(self.transform, data) + result = self.transform(data) if self.transform is not None else data if isinstance(result, dict) or (isinstance(result, list) and isinstance(result[0], dict)): return result - raise AssertionError("With a dict supplied to apply_transform, should return a dict or a list of dicts.") + raise AssertionError("With a dict supplied to Compose, should return a dict or a list of dicts.") class CSVDataset(Dataset): diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py index 769ae33b46..5b9e32afca 100644 --- a/monai/data/dataset_summary.py +++ b/monai/data/dataset_summary.py @@ -84,6 +84,7 @@ def collect_meta_data(self): """ for data in self.data_loader: + meta_dict = {} if isinstance(data[self.image_key], MetaTensor): meta_dict = data[self.image_key].meta elif self.meta_key in data: diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 2361bb63a7..f5e199e2a3 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1331,7 +1331,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: header[MetaKeys.SPACE] = SpaceKeys.LPS # assuming LPS if not specified header[MetaKeys.AFFINE] = header[MetaKeys.ORIGINAL_AFFINE].copy() - header[MetaKeys.SPATIAL_SHAPE] = header["sizes"] + header[MetaKeys.SPATIAL_SHAPE] = header["sizes"].copy() [header.pop(k) for k in ("sizes", "space origin", "space directions")] # rm duplicated data in header if self.channel_dim is None: # default to "no_channel" or -1 diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index cabf06ce89..507cf411d6 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -116,7 +116,7 @@ def load_net_with_metadata( Returns: Triple containing loaded object, metadata dict, and extra files dict containing other file data if present """ - extra_files = {f: "" for f in more_extra_files} + extra_files = dict.fromkeys(more_extra_files, "") extra_files[METADATA_FILENAME] = "" jit_obj = torch.jit.load(filename_prefix_or_stream, map_location, extra_files) diff --git a/monai/data/utils.py b/monai/data/utils.py index 585f02ec9e..7a08300abb 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -53,10 +53,6 @@ pytorch_after, ) -if pytorch_after(1, 13): - # import private code for reuse purposes, comment in case things break in the future - from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map - pd, _ = optional_import("pandas") DataFrame, _ = optional_import("pandas", name="DataFrame") nib, _ = optional_import("nibabel") @@ -454,8 +450,13 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None): Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor` and so should not be used as a collate function directly in dataloaders. """ - collate_fn = collate_tensor_fn if pytorch_after(1, 13) else default_collate - collated = collate_fn(batch) # type: ignore + if pytorch_after(1, 13): + from torch.utils.data._utils.collate import collate_tensor_fn # imported here for pylint/mypy issues + + collated = collate_tensor_fn(batch) + else: + collated = default_collate(batch) + meta_dicts = [i.meta or TraceKeys.NONE for i in batch] common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)]) if common_: @@ -496,6 +497,8 @@ def list_data_collate(batch: Sequence): if pytorch_after(1, 13): # needs to go here to avoid circular import + from torch.utils.data._utils.collate import default_collate_fn_map + from monai.data.meta_tensor import MetaTensor default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn}) diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py index 020d8d23fd..75319853d9 100644 --- a/monai/networks/blocks/spatialattention.py +++ b/monai/networks/blocks/spatialattention.py @@ -68,7 +68,7 @@ def forward(self, x: torch.Tensor): h, w = x.shape[2], x.shape[3] rearrange_input = Rearrange("b c h w -> b (h w) c") rearrange_output = Rearrange("b (h w) c -> b c h w", h=h, w=w) - if self.spatial_dims == 3: + else: h, w, d = x.shape[2], x.shape[3], x.shape[4] rearrange_input = Rearrange("b c h w d -> b (h w d) c") rearrange_output = Rearrange("b (h w d) c -> b c h w d", h=h, w=w, d=d) diff --git a/monai/networks/nets/daf3d.py b/monai/networks/nets/daf3d.py index c9a18c746a..02e5bb022a 100644 --- a/monai/networks/nets/daf3d.py +++ b/monai/networks/nets/daf3d.py @@ -13,6 +13,7 @@ from collections import OrderedDict from collections.abc import Callable, Sequence +from functools import partial import torch import torch.nn as nn @@ -25,6 +26,7 @@ from monai.networks.blocks.convolutions import Convolution from monai.networks.blocks.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork from monai.networks.layers.factories import Conv, Norm +from monai.networks.layers.utils import get_norm_layer from monai.networks.nets.resnet import ResNet, ResNetBottleneck __all__ = [ @@ -170,33 +172,37 @@ class Daf3dResNetBottleneck(ResNetBottleneck): spatial_dims: number of spatial dimensions of the input image. stride: stride to use for second conv layer. downsample: which downsample layer to use. + norm: which normalization layer to use. Defaults to group. """ expansion = 2 - def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None): - norm_type: Callable = Norm[Norm.GROUP, spatial_dims] + def __init__( + self, in_planes, planes, spatial_dims=3, stride=1, downsample=None, norm=("group", {"num_groups": 32}) + ): conv_type: Callable = Conv[Conv.CONV, spatial_dims] + norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims) + # in case downsample uses batch norm, change to group norm if isinstance(downsample, nn.Sequential): downsample = nn.Sequential( conv_type(in_planes, planes * self.expansion, kernel_size=1, stride=stride, bias=False), - norm_type(num_groups=32, num_channels=planes * self.expansion), + norm_layer(channels=planes * self.expansion), ) super().__init__(in_planes, planes, spatial_dims, stride, downsample) # change norm from batch to group norm - self.bn1 = norm_type(num_groups=32, num_channels=planes) - self.bn2 = norm_type(num_groups=32, num_channels=planes) - self.bn3 = norm_type(num_groups=32, num_channels=planes * self.expansion) + self.bn1 = norm_layer(channels=planes) + self.bn2 = norm_layer(channels=planes) + self.bn3 = norm_layer(channels=planes * self.expansion) # adapt second convolution to work with groups self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, stride=stride, groups=32, bias=False) # adapt activation function - self.relu = nn.PReLU() # type: ignore + self.relu = nn.PReLU() class Daf3dResNetDilatedBottleneck(Daf3dResNetBottleneck): @@ -212,8 +218,10 @@ class Daf3dResNetDilatedBottleneck(Daf3dResNetBottleneck): downsample: which downsample layer to use. """ - def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None): - super().__init__(in_planes, planes, spatial_dims, stride, downsample) + def __init__( + self, in_planes, planes, spatial_dims=3, stride=1, downsample=None, norm=("group", {"num_groups": 32}) + ): + super().__init__(in_planes, planes, spatial_dims, stride, downsample, norm) # add dilation in second convolution conv_type: Callable = Conv[Conv.CONV, spatial_dims] @@ -287,7 +295,7 @@ def __init__( n_input_channels, self.in_planes, kernel_size=7, stride=(1, 2, 2), padding=(3, 3, 3), bias=False ) self.bn1 = norm_type(32, 64) - self.relu = nn.PReLU() # type: ignore + self.relu = nn.PReLU() # adapt layers to our needs self.layer1 = self._make_layer(Daf3dResNetBottleneck, block_inplanes[0], layers[0], spatial_dims, shortcut_type) diff --git a/monai/networks/nets/quicknat.py b/monai/networks/nets/quicknat.py index cbcccf24d7..bbc4e7e490 100644 --- a/monai/networks/nets/quicknat.py +++ b/monai/networks/nets/quicknat.py @@ -168,6 +168,8 @@ def _get_layer(self, in_channels, out_channels, dilation): def forward(self, input, _): i = 0 result = input + result1 = input # this will not stay this value, needed here for pylint/mypy + for l in self.children(): # ignoring the max (un-)pool and droupout already added in the initial initialization step if isinstance(l, (nn.MaxPool2d, nn.MaxUnpool2d, nn.Dropout2d)): diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 74d15bc6bf..6e61db07ca 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -22,8 +22,8 @@ import torch.nn as nn from monai.networks.blocks.encoder import BaseEncoder -from monai.networks.layers.factories import Conv, Norm, Pool -from monai.networks.layers.utils import get_pool_layer +from monai.networks.layers.factories import Conv, Pool +from monai.networks.layers.utils import get_act_layer, get_norm_layer, get_pool_layer from monai.utils import ensure_tuple_rep from monai.utils.module import look_up_option, optional_import @@ -57,7 +57,6 @@ "resnet200": ("bottleneck", [3, 24, 36, 3], "B", False, False), } - logger = logging.getLogger(__name__) @@ -79,6 +78,8 @@ def __init__( spatial_dims: int = 3, stride: int = 1, downsample: nn.Module | partial | None = None, + act: str | tuple = ("relu", {"inplace": True}), + norm: str | tuple = "batch", ) -> None: """ Args: @@ -87,17 +88,18 @@ def __init__( spatial_dims: number of spatial dimensions of the input image. stride: stride to use for first conv layer. downsample: which downsample layer to use. + act: activation type and arguments. Defaults to relu. + norm: feature normalization type and arguments. Defaults to batch norm. """ super().__init__() conv_type: Callable = Conv[Conv.CONV, spatial_dims] - norm_type: Callable = Norm[Norm.BATCH, spatial_dims] self.conv1 = conv_type(in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False) - self.bn1 = norm_type(planes) - self.relu = nn.ReLU(inplace=True) + self.bn1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes) + self.act = get_act_layer(name=act) self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False) - self.bn2 = norm_type(planes) + self.bn2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes) self.downsample = downsample self.stride = stride @@ -106,7 +108,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out: torch.Tensor = self.conv1(x) out = self.bn1(out) - out = self.relu(out) + out = self.act(out) out = self.conv2(out) out = self.bn2(out) @@ -115,7 +117,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: residual = self.downsample(x) out += residual - out = self.relu(out) + out = self.act(out) return out @@ -130,6 +132,8 @@ def __init__( spatial_dims: int = 3, stride: int = 1, downsample: nn.Module | partial | None = None, + act: str | tuple = ("relu", {"inplace": True}), + norm: str | tuple = "batch", ) -> None: """ Args: @@ -138,20 +142,22 @@ def __init__( spatial_dims: number of spatial dimensions of the input image. stride: stride to use for second conv layer. downsample: which downsample layer to use. + act: activation type and arguments. Defaults to relu. + norm: feature normalization type and arguments. Defaults to batch norm. """ super().__init__() conv_type: Callable = Conv[Conv.CONV, spatial_dims] - norm_type: Callable = Norm[Norm.BATCH, spatial_dims] + norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims) self.conv1 = conv_type(in_planes, planes, kernel_size=1, bias=False) - self.bn1 = norm_type(planes) + self.bn1 = norm_layer(channels=planes) self.conv2 = conv_type(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn2 = norm_type(planes) + self.bn2 = norm_layer(channels=planes) self.conv3 = conv_type(planes, planes * self.expansion, kernel_size=1, bias=False) - self.bn3 = norm_type(planes * self.expansion) - self.relu = nn.ReLU(inplace=True) + self.bn3 = norm_layer(channels=planes * self.expansion) + self.act = get_act_layer(name=act) self.downsample = downsample self.stride = stride @@ -160,11 +166,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out: torch.Tensor = self.conv1(x) out = self.bn1(out) - out = self.relu(out) + out = self.act(out) out = self.conv2(out) out = self.bn2(out) - out = self.relu(out) + out = self.act(out) out = self.conv3(out) out = self.bn3(out) @@ -173,7 +179,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: residual = self.downsample(x) out += residual - out = self.relu(out) + out = self.act(out) return out @@ -203,6 +209,8 @@ class ResNet(nn.Module): num_classes: number of output (classifications). feed_forward: whether to add the FC layer for the output, default to `True`. bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`. + act: activation type and arguments. Defaults to relu. + norm: feature normalization type and arguments. Defaults to batch norm. """ @@ -221,6 +229,8 @@ def __init__( num_classes: int = 400, feed_forward: bool = True, bias_downsample: bool = True, # for backwards compatibility (also see PR #5477) + act: str | tuple = ("relu", {"inplace": True}), + norm: str | tuple = "batch", ) -> None: super().__init__() @@ -233,7 +243,6 @@ def __init__( raise ValueError("Unknown block '%s', use basic or bottleneck" % block) conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims] - norm_type: type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims] pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims] avgp_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[ Pool.ADAPTIVEAVG, spatial_dims @@ -257,8 +266,10 @@ def __init__( padding=tuple(k // 2 for k in conv1_kernel_size), bias=False, ) - self.bn1 = norm_type(self.in_planes) - self.relu = nn.ReLU(inplace=True) + + norm_layer = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=self.in_planes) + self.bn1 = norm_layer + self.act = get_act_layer(name=act) self.maxpool = pool_type(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], spatial_dims, shortcut_type) self.layer2 = self._make_layer(block, block_inplanes[1], layers[1], spatial_dims, shortcut_type, stride=2) @@ -270,7 +281,7 @@ def __init__( for m in self.modules(): if isinstance(m, conv_type): nn.init.kaiming_normal_(torch.as_tensor(m.weight), mode="fan_out", nonlinearity="relu") - elif isinstance(m, norm_type): + elif isinstance(m, type(norm_layer)): nn.init.constant_(torch.as_tensor(m.weight), 1) nn.init.constant_(torch.as_tensor(m.bias), 0) elif isinstance(m, nn.Linear): @@ -290,9 +301,9 @@ def _make_layer( spatial_dims: int, shortcut_type: str, stride: int = 1, + norm: str | tuple = "batch", ) -> nn.Sequential: conv_type: Callable = Conv[Conv.CONV, spatial_dims] - norm_type: Callable = Norm[Norm.BATCH, spatial_dims] downsample: nn.Module | partial | None = None if stride != 1 or self.in_planes != planes * block.expansion: @@ -312,25 +323,30 @@ def _make_layer( stride=stride, bias=self.bias_downsample, ), - norm_type(planes * block.expansion), + get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes * block.expansion), ) layers = [ block( - in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample + in_planes=self.in_planes, + planes=planes, + spatial_dims=spatial_dims, + stride=stride, + downsample=downsample, + norm=norm, ) ] self.in_planes = planes * block.expansion for _i in range(1, blocks): - layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims)) + layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims, norm=norm)) return nn.Sequential(*layers) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv1(x) x = self.bn1(x) - x = self.relu(x) + x = self.act(x) if not self.no_max_pool: x = self.maxpool(x) @@ -397,7 +413,7 @@ def forward(self, inputs: torch.Tensor): """ x = self.conv1(inputs) x = self.bn1(x) - x = self.relu(x) + x = self.act(x) features = [] features.append(x) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 6f96dfd291..3900c866b3 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -347,7 +347,7 @@ def window_partition(x, window_size): x: input tensor. window_size: local window size. """ - x_shape = x.size() + x_shape = x.size() # length 4 or 5 only if len(x_shape) == 5: b, d, h, w, c = x_shape x = x.view( @@ -363,10 +363,11 @@ def window_partition(x, window_size): windows = ( x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c) ) - elif len(x_shape) == 4: + else: # if len(x_shape) == 4: b, h, w, c = x.shape x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c) + return windows @@ -613,7 +614,7 @@ def forward_part1(self, x, mask_matrix): _, dp, hp, wp, _ = x.shape dims = [b, dp, hp, wp] - elif len(x_shape) == 4: + else: # elif len(x_shape) == 4 b, h, w, c = x.shape window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size) pad_l = pad_t = 0 diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py index 19e24d94b8..2a0121d063 100644 --- a/monai/networks/schedulers/ddim.py +++ b/monai/networks/schedulers/ddim.py @@ -184,6 +184,10 @@ def step( beta_prod_t = 1 - alpha_prod_t + # predefinitions satisfy pylint/mypy, these values won't be ultimately used + pred_original_sample = sample + pred_epsilon = model_output + # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.prediction_type == DDIMPredictionType.EPSILON: @@ -258,6 +262,10 @@ def reversed_step( beta_prod_t = 1 - alpha_prod_t + # predefinitions satisfy pylint/mypy, these values won't be ultimately used + pred_original_sample = sample + pred_epsilon = model_output + # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index ab9adb6a99..ef1da2d855 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -671,6 +671,7 @@ in_bounds, is_empty, is_positive, + map_and_generate_sampling_centers, map_binary_to_indices, map_classes_to_indices, map_spatial_axes, diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py index 0b495c8623..9186a5c46f 100644 --- a/monai/transforms/regularization/array.py +++ b/monai/transforms/regularization/array.py @@ -16,6 +16,9 @@ import torch +from monai.data.meta_obj import get_track_meta +from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor + from ..transform import RandomizableTransform __all__ = ["MixUp", "CutMix", "CutOut", "Mixer"] @@ -53,9 +56,11 @@ def randomize(self, data=None) -> None: as needed. You need to call this method everytime you apply the transform to a new batch. """ + super().randomize(None) self._params = ( torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32), self.R.permutation(self.batch_size), + [torch.from_numpy(self.R.randint(0, d, size=(1,))) for d in data.shape[2:]] if data is not None else [], ) @@ -69,7 +74,7 @@ class MixUp(Mixer): """ def apply(self, data: torch.Tensor): - weight, perm = self._params + weight, perm, _ = self._params nsamples, *dims = data.shape if len(weight) != nsamples: raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}") @@ -80,11 +85,21 @@ def apply(self, data: torch.Tensor): mixweight = weight[(Ellipsis,) + (None,) * len(dims)] return mixweight * data + (1 - mixweight) * data[perm, ...] - def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None): - self.randomize() + def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True): + data_t = convert_to_tensor(data, track_meta=get_track_meta()) + labels_t = data_t # will not stay this value, needed to satisfy pylint/mypy + + if labels is not None: + labels_t = convert_to_tensor(labels, track_meta=get_track_meta()) + if randomize: + self.randomize() if labels is None: - return self.apply(data) - return self.apply(data), self.apply(labels) + return convert_to_dst_type(self.apply(data_t), dst=data)[0] + + return ( + convert_to_dst_type(self.apply(data_t), dst=data)[0], + convert_to_dst_type(self.apply(labels_t), dst=labels)[0], + ) class CutMix(Mixer): @@ -113,14 +128,13 @@ class CutMix(Mixer): """ def apply(self, data: torch.Tensor): - weights, perm = self._params + weights, perm, coords = self._params nsamples, _, *dims = data.shape if len(weights) != nsamples: raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") mask = torch.ones_like(data) for s, weight in enumerate(weights): - coords = [torch.randint(0, d, size=(1,)) for d in dims] lengths = [d * sqrt(1 - weight) for d in dims] idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)] mask[s][idx] = 0 @@ -128,7 +142,7 @@ def apply(self, data: torch.Tensor): return mask * data + (1 - mask) * data[perm, ...] def apply_on_labels(self, labels: torch.Tensor): - weights, perm = self._params + weights, perm, _ = self._params nsamples, *dims = labels.shape if len(weights) != nsamples: raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") @@ -136,10 +150,20 @@ def apply_on_labels(self, labels: torch.Tensor): mixweight = weights[(Ellipsis,) + (None,) * len(dims)] return mixweight * labels + (1 - mixweight) * labels[perm, ...] - def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None): - self.randomize() - augmented = self.apply(data) - return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented + def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True): + data_t = convert_to_tensor(data, track_meta=get_track_meta()) + augmented_label = None + + if labels is not None: + labels_t = convert_to_tensor(labels, track_meta=get_track_meta()) + if randomize: + self.randomize(data) + augmented = convert_to_dst_type(self.apply(data_t), dst=data)[0] + + if labels is not None: + augmented_label = convert_to_dst_type(self.apply(labels_t), dst=labels)[0] + + return (augmented, augmented_label) if labels is not None else augmented class CutOut(Mixer): @@ -155,20 +179,21 @@ class CutOut(Mixer): """ def apply(self, data: torch.Tensor): - weights, _ = self._params + weights, _, coords = self._params nsamples, _, *dims = data.shape if len(weights) != nsamples: raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") mask = torch.ones_like(data) for s, weight in enumerate(weights): - coords = [torch.randint(0, d, size=(1,)) for d in dims] lengths = [d * sqrt(1 - weight) for d in dims] idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)] mask[s][idx] = 0 return mask * data - def __call__(self, data: torch.Tensor): - self.randomize() - return self.apply(data) + def __call__(self, data: torch.Tensor, randomize=True): + data_t = convert_to_tensor(data, track_meta=get_track_meta()) + if randomize: + self.randomize(data) + return convert_to_dst_type(self.apply(data_t), dst=data)[0] diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py index 373913da99..d8815e47b9 100644 --- a/monai/transforms/regularization/dictionary.py +++ b/monai/transforms/regularization/dictionary.py @@ -11,16 +11,23 @@ from __future__ import annotations +from collections.abc import Hashable + +import numpy as np + from monai.config import KeysCollection +from monai.config.type_definitions import NdarrayOrTensor +from monai.data.meta_obj import get_track_meta +from monai.utils import convert_to_tensor from monai.utils.misc import ensure_tuple -from ..transform import MapTransform +from ..transform import MapTransform, RandomizableTransform from .array import CutMix, CutOut, MixUp __all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"] -class MixUpd(MapTransform): +class MixUpd(MapTransform, RandomizableTransform): """ Dictionary-based version :py:class:`monai.transforms.MixUp`. @@ -31,18 +38,24 @@ class MixUpd(MapTransform): def __init__( self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False ) -> None: - super().__init__(keys, allow_missing_keys) + MapTransform.__init__(self, keys, allow_missing_keys) self.mixup = MixUp(batch_size, alpha) + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> MixUpd: + super().set_random_state(seed, state) + self.mixup.set_random_state(seed, state) + return self + def __call__(self, data): - self.mixup.randomize() - result = dict(data) - for k in self.keys: - result[k] = self.mixup.apply(data[k]) - return result + d = dict(data) + # all the keys share the same random state + self.mixup.randomize(None) + for k in self.key_iterator(d): + d[k] = self.mixup(data[k], randomize=False) + return d -class CutMixd(MapTransform): +class CutMixd(MapTransform, RandomizableTransform): """ Dictionary-based version :py:class:`monai.transforms.CutMix`. @@ -63,17 +76,27 @@ def __init__( self.mixer = CutMix(batch_size, alpha) self.label_keys = ensure_tuple(label_keys) if label_keys is not None else [] - def __call__(self, data): - self.mixer.randomize() - result = dict(data) - for k in self.keys: - result[k] = self.mixer.apply(data[k]) - for k in self.label_keys: - result[k] = self.mixer.apply_on_labels(data[k]) - return result - + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> CutMixd: + super().set_random_state(seed, state) + self.mixer.set_random_state(seed, state) + return self -class CutOutd(MapTransform): + def __call__(self, data): + d = dict(data) + first_key: Hashable = self.first_key(d) + if first_key == (): + out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta()) + return out + self.mixer.randomize(d[first_key]) + for key, label_key in self.key_iterator(d, self.label_keys): + ret = self.mixer(data[key], data.get(label_key, None), randomize=False) + d[key] = ret[0] + if label_key in d: + d[label_key] = ret[1] + return d + + +class CutOutd(MapTransform, RandomizableTransform): """ Dictionary-based version :py:class:`monai.transforms.CutOut`. @@ -84,12 +107,21 @@ def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bo super().__init__(keys, allow_missing_keys) self.cutout = CutOut(batch_size) + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> CutOutd: + super().set_random_state(seed, state) + self.cutout.set_random_state(seed, state) + return self + def __call__(self, data): - result = dict(data) - self.cutout.randomize() - for k in self.keys: - result[k] = self.cutout(data[k]) - return result + d = dict(data) + first_key: Hashable = self.first_key(d) + if first_key == (): + out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta()) + return out + self.cutout.randomize(d[first_key]) + for k in self.key_iterator(d): + d[k] = self.cutout(data[k], randomize=False) + return d MixUpD = MixUpDict = MixUpd diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 560dbac346..d8461d927b 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -108,6 +108,7 @@ "in_bounds", "is_empty", "is_positive", + "map_and_generate_sampling_centers", "map_binary_to_indices", "map_classes_to_indices", "map_spatial_axes", @@ -368,6 +369,70 @@ def check_non_lazy_pending_ops( warnings.warn(msg) +def map_and_generate_sampling_centers( + label: NdarrayOrTensor, + spatial_size: Sequence[int] | int, + num_samples: int, + label_spatial_shape: Sequence[int] | None = None, + num_classes: int | None = None, + image: NdarrayOrTensor | None = None, + image_threshold: float = 0.0, + max_samples_per_class: int | None = None, + ratios: list[float | int] | None = None, + rand_state: np.random.RandomState | None = None, + allow_smaller: bool = False, + warn: bool = True, +) -> tuple[tuple]: + """ + Combine "map_classes_to_indices" and "generate_label_classes_crop_centers" functions, return crop center coordinates. + This calls `map_classes_to_indices` to get indices from `label`, gets the shape from `label_spatial_shape` + is given otherwise from the labels, calls `generate_label_classes_crop_centers`, and returns its results. + + Args: + label: use the label data to get the indices of every class. + spatial_size: spatial size of the ROIs to be sampled. + num_samples: total sample centers to be generated. + label_spatial_shape: spatial shape of the original label data to unravel selected centers. + indices: sequence of pre-computed foreground indices of every class in 1 dimension. + num_classes: number of classes for argmax label, not necessary for One-Hot label. + image: if image is not None, only return the indices of every class that are within the valid + region of the image (``image > image_threshold``). + image_threshold: if enabled `image`, use ``image > image_threshold`` to + determine the valid image content area and select class indices only in this area. + max_samples_per_class: maximum length of indices in each class to reduce memory consumption. + Default is None, no subsampling. + ratios: ratios of every class in the label to generate crop centers, including background class. + if None, every class will have the same ratio to generate crop centers. + rand_state: numpy randomState object to align with other modules. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). + warn: if `True` prints a warning if a class is not present in the label. + Returns: + Tuple of crop centres + """ + if label is None: + raise ValueError("label must not be None.") + indices = map_classes_to_indices(label, num_classes, image, image_threshold, max_samples_per_class) + + if label_spatial_shape is not None: + _shape = label_spatial_shape + elif isinstance(label, monai.data.MetaTensor): + _shape = label.peek_pending_shape() + else: + _shape = label.shape[1:] + + if _shape is None: + raise ValueError( + "label_spatial_shape or label with a known shape must be provided to infer the output spatial shape." + ) + centers = generate_label_classes_crop_centers( + spatial_size, num_samples, _shape, indices, ratios, rand_state, allow_smaller, warn + ) + + return ensure_tuple(centers) + + def map_binary_to_indices( label: NdarrayOrTensor, image: NdarrayOrTensor | None = None, image_threshold: float = 0.0 ) -> tuple[NdarrayOrTensor, NdarrayOrTensor]: diff --git a/monai/transforms/utils_create_transform_ims.py b/monai/transforms/utils_create_transform_ims.py index 4b5990abd3..a29fd4dbf9 100644 --- a/monai/transforms/utils_create_transform_ims.py +++ b/monai/transforms/utils_create_transform_ims.py @@ -269,11 +269,9 @@ def update_docstring(code_path, transform_name): def pre_process_data(data, ndim, is_map, is_post): - """If transform requires 2D data, then convert to 2D""" + """If transform requires 2D data, then convert to 2D by selecting the middle of the last dimension.""" if ndim == 2: - for k in keys: - data[k] = data[k][..., data[k].shape[-1] // 2] - + data = {k: v[..., v.shape[-1] // 2] for k, v in data.items()} if is_map: return data return data[CommonKeys.LABEL] if is_post else data[CommonKeys.IMAGE] diff --git a/requirements-dev.txt b/requirements-dev.txt index 35ff3382be..c50d9248df 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -33,8 +33,8 @@ tifffile; platform_system == "Linux" or platform_system == "Darwin" pandas requests einops -transformers>=4.36.0 -mlflow>=1.28.0, <=2.11.3 +transformers>=4.36.0, <4.41.0; python_version <= '3.10' +mlflow>=2.12.2 clearml>=1.10.0rc0 matplotlib>=3.6.3 tensorboardX diff --git a/setup.cfg b/setup.cfg index c90b043c1c..7b82784a8a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -65,8 +65,8 @@ all = imagecodecs pandas einops - transformers<4.22; python_version <= '3.10' - mlflow>=1.28.0, <=2.11.3 + transformers>=4.36.0, <4.41.0; python_version <= '3.10' + mlflow>=2.12.2 clearml>=1.10.0rc0 matplotlib>=3.6.3 tensorboardX @@ -123,9 +123,9 @@ pandas = einops = einops transformers = - transformers<4.22; python_version <= '3.10' + transformers>=4.36.0, <4.41.0; python_version <= '3.10' mlflow = - mlflow>=1.28.0, <=2.11.3 + mlflow>=2.12.2 matplotlib = matplotlib>=3.6.3 clearml = diff --git a/tests/hvd_evenly_divisible_all_gather.py b/tests/hvd_evenly_divisible_all_gather.py index 78c6ca06bc..732ad13b83 100644 --- a/tests/hvd_evenly_divisible_all_gather.py +++ b/tests/hvd_evenly_divisible_all_gather.py @@ -30,10 +30,10 @@ def test_data(self): self._run() def _run(self): - if hvd.rank() == 0: - data1 = torch.tensor([[1, 2], [3, 4]]) - data2 = torch.tensor([[1.0, 2.0]]) - data3 = torch.tensor(7) + # if hvd.rank() == 0: + data1 = torch.tensor([[1, 2], [3, 4]]) + data2 = torch.tensor([[1.0, 2.0]]) + data3 = torch.tensor(7) if hvd.rank() == 1: data1 = torch.tensor([[5, 6]]) diff --git a/tests/test_arraydataset.py b/tests/test_arraydataset.py index efc014a267..b61b3c139c 100644 --- a/tests/test_arraydataset.py +++ b/tests/test_arraydataset.py @@ -41,7 +41,7 @@ class TestCompose(Compose): - def __call__(self, input_, lazy): + def __call__(self, input_, lazy=False): img = self.transforms[0](input_) metadata = img.meta img = self.transforms[1](img) diff --git a/tests/test_clip_intensity_percentiles.py b/tests/test_clip_intensity_percentiles.py index af157446f6..77f811db87 100644 --- a/tests/test_clip_intensity_percentiles.py +++ b/tests/test_clip_intensity_percentiles.py @@ -18,9 +18,32 @@ from monai.transforms import ClipIntensityPercentiles from monai.transforms.utils import soft_clip from monai.transforms.utils_pytorch_numpy_unification import clip, percentile +from monai.utils.type_conversion import convert_to_tensor from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose +def test_hard_clip_func(im, lower, upper): + im_t = convert_to_tensor(im) + if lower is None: + upper = percentile(im_t, upper) + elif upper is None: + lower = percentile(im_t, lower) + else: + lower, upper = percentile(im_t, (lower, upper)) + return clip(im_t, lower, upper) + + +def test_soft_clip_func(im, lower, upper): + im_t = convert_to_tensor(im) + if lower is None: + upper = percentile(im_t, upper) + elif upper is None: + lower = percentile(im_t, lower) + else: + lower, upper = percentile(im_t, (lower, upper)) + return soft_clip(im_t, minv=lower, maxv=upper, sharpness_factor=1.0, dtype=torch.float32) + + class TestClipIntensityPercentiles2D(NumpyImageTestCase2D): @parameterized.expand([[p] for p in TEST_NDARRAYS]) @@ -28,8 +51,7 @@ def test_hard_clipping_two_sided(self, p): hard_clipper = ClipIntensityPercentiles(upper=95, lower=5) im = p(self.imt) result = hard_clipper(im) - lower, upper = percentile(im, (5, 95)) - expected = clip(im, lower, upper) + expected = test_hard_clip_func(im, 5, 95) assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) @@ -37,8 +59,7 @@ def test_hard_clipping_one_sided_high(self, p): hard_clipper = ClipIntensityPercentiles(upper=95, lower=None) im = p(self.imt) result = hard_clipper(im) - lower, upper = percentile(im, (0, 95)) - expected = clip(im, lower, upper) + expected = test_hard_clip_func(im, 0, 95) assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) @@ -46,8 +67,7 @@ def test_hard_clipping_one_sided_low(self, p): hard_clipper = ClipIntensityPercentiles(upper=None, lower=5) im = p(self.imt) result = hard_clipper(im) - lower, upper = percentile(im, (5, 100)) - expected = clip(im, lower, upper) + expected = test_hard_clip_func(im, 5, 100) assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) @@ -55,37 +75,35 @@ def test_soft_clipping_two_sided(self, p): soft_clipper = ClipIntensityPercentiles(upper=95, lower=5, sharpness_factor=1.0) im = p(self.imt) result = soft_clipper(im) - lower, upper = percentile(im, (5, 95)) - expected = soft_clip(im, sharpness_factor=1.0, minv=lower, maxv=upper, dtype=torch.float32) - # the rtol is set to 1e-6 because the logaddexp function used in softplus is not stable accross torch and numpy - assert_allclose(result, p(expected), type_test="tensor", rtol=1e-6, atol=0) + expected = test_soft_clip_func(im, 5, 95) + # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_soft_clipping_one_sided_high(self, p): soft_clipper = ClipIntensityPercentiles(upper=95, lower=None, sharpness_factor=1.0) im = p(self.imt) result = soft_clipper(im) - upper = percentile(im, 95) - expected = soft_clip(im, sharpness_factor=1.0, minv=None, maxv=upper, dtype=torch.float32) - # the rtol is set to 5e-5 because the logaddexp function used in softplus is not stable accross torch and numpy - assert_allclose(result, p(expected), type_test="tensor", rtol=5e-5, atol=0) + expected = test_soft_clip_func(im, None, 95) + # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_soft_clipping_one_sided_low(self, p): soft_clipper = ClipIntensityPercentiles(upper=None, lower=5, sharpness_factor=1.0) im = p(self.imt) result = soft_clipper(im) - lower = percentile(im, 5) - expected = soft_clip(im, sharpness_factor=1.0, minv=lower, maxv=None, dtype=torch.float32) - # the rtol is set to 1e-6 because the logaddexp function used in softplus is not stable accross torch and numpy - assert_allclose(result, p(expected), type_test="tensor", rtol=1e-6, atol=0) + expected = test_soft_clip_func(im, 5, None) + # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_channel_wise(self, p): clipper = ClipIntensityPercentiles(upper=95, lower=5, channel_wise=True) im = p(self.imt) result = clipper(im) - for i, c in enumerate(im): + im_t = convert_to_tensor(self.imt) + for i, c in enumerate(im_t): lower, upper = percentile(c, (5, 95)) expected = clip(c, lower, upper) assert_allclose(result[i], p(expected), type_test="tensor", rtol=1e-4, atol=0) @@ -118,8 +136,7 @@ def test_hard_clipping_two_sided(self, p): hard_clipper = ClipIntensityPercentiles(upper=95, lower=5) im = p(self.imt) result = hard_clipper(im) - lower, upper = percentile(im, (5, 95)) - expected = clip(im, lower, upper) + expected = test_hard_clip_func(im, 5, 95) assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) @@ -127,8 +144,7 @@ def test_hard_clipping_one_sided_high(self, p): hard_clipper = ClipIntensityPercentiles(upper=95, lower=None) im = p(self.imt) result = hard_clipper(im) - lower, upper = percentile(im, (0, 95)) - expected = clip(im, lower, upper) + expected = test_hard_clip_func(im, 0, 95) assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) @@ -136,8 +152,7 @@ def test_hard_clipping_one_sided_low(self, p): hard_clipper = ClipIntensityPercentiles(upper=None, lower=5) im = p(self.imt) result = hard_clipper(im) - lower, upper = percentile(im, (5, 100)) - expected = clip(im, lower, upper) + expected = test_hard_clip_func(im, 5, 100) assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) @@ -145,37 +160,35 @@ def test_soft_clipping_two_sided(self, p): soft_clipper = ClipIntensityPercentiles(upper=95, lower=5, sharpness_factor=1.0) im = p(self.imt) result = soft_clipper(im) - lower, upper = percentile(im, (5, 95)) - expected = soft_clip(im, sharpness_factor=1.0, minv=lower, maxv=upper, dtype=torch.float32) - # the rtol is set to 1e-6 because the logaddexp function used in softplus is not stable accross torch and numpy - assert_allclose(result, p(expected), type_test="tensor", rtol=1e-6, atol=0) + expected = test_soft_clip_func(im, 5, 95) + # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_soft_clipping_one_sided_high(self, p): soft_clipper = ClipIntensityPercentiles(upper=95, lower=None, sharpness_factor=1.0) im = p(self.imt) result = soft_clipper(im) - upper = percentile(im, 95) - expected = soft_clip(im, sharpness_factor=1.0, minv=None, maxv=upper, dtype=torch.float32) - # the rtol is set to 5e-5 because the logaddexp function used in softplus is not stable accross torch and numpy - assert_allclose(result, p(expected), type_test="tensor", rtol=5e-5, atol=0) + expected = test_soft_clip_func(im, None, 95) + # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_soft_clipping_one_sided_low(self, p): soft_clipper = ClipIntensityPercentiles(upper=None, lower=5, sharpness_factor=1.0) im = p(self.imt) result = soft_clipper(im) - lower = percentile(im, 5) - expected = soft_clip(im, sharpness_factor=1.0, minv=lower, maxv=None, dtype=torch.float32) - # the rtol is set to 1e-6 because the logaddexp function used in softplus is not stable accross torch and numpy - assert_allclose(result, p(expected), type_test="tensor", rtol=1e-6, atol=0) + expected = test_soft_clip_func(im, 5, None) + # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_channel_wise(self, p): clipper = ClipIntensityPercentiles(upper=95, lower=5, channel_wise=True) im = p(self.imt) result = clipper(im) - for i, c in enumerate(im): + im_t = convert_to_tensor(self.imt) + for i, c in enumerate(im_t): lower, upper = percentile(c, (5, 95)) expected = clip(c, lower, upper) assert_allclose(result[i], p(expected), type_test="tensor", rtol=1e-4, atol=0) diff --git a/tests/test_clip_intensity_percentilesd.py b/tests/test_clip_intensity_percentilesd.py index ed4fc588cb..3e06b18418 100644 --- a/tests/test_clip_intensity_percentilesd.py +++ b/tests/test_clip_intensity_percentilesd.py @@ -13,14 +13,15 @@ import unittest -import torch from parameterized import parameterized from monai.transforms import ClipIntensityPercentilesd -from monai.transforms.utils import soft_clip from monai.transforms.utils_pytorch_numpy_unification import clip, percentile +from monai.utils.type_conversion import convert_to_tensor from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose +from .test_clip_intensity_percentiles import test_hard_clip_func, test_soft_clip_func + class TestClipIntensityPercentilesd2D(NumpyImageTestCase2D): @@ -30,8 +31,7 @@ def test_hard_clipping_two_sided(self, p): hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5) im = p(self.imt) result = hard_clipper({key: im}) - lower, upper = percentile(im, (5, 95)) - expected = clip(im, lower, upper) + expected = test_hard_clip_func(im, 5, 95) assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) @@ -40,8 +40,7 @@ def test_hard_clipping_one_sided_high(self, p): hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=None) im = p(self.imt) result = hard_clipper({key: im}) - lower, upper = percentile(im, (0, 95)) - expected = clip(im, lower, upper) + expected = test_hard_clip_func(im, 0, 95) assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) @@ -50,8 +49,7 @@ def test_hard_clipping_one_sided_low(self, p): hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=None, lower=5) im = p(self.imt) result = hard_clipper({key: im}) - lower, upper = percentile(im, (5, 100)) - expected = clip(im, lower, upper) + expected = test_hard_clip_func(im, 5, 100) assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) @@ -60,10 +58,9 @@ def test_soft_clipping_two_sided(self, p): soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, sharpness_factor=1.0) im = p(self.imt) result = soft_clipper({key: im}) - lower, upper = percentile(im, (5, 95)) - expected = soft_clip(im, sharpness_factor=1.0, minv=lower, maxv=upper, dtype=torch.float32) - # the rtol is set to 1e-6 because the logaddexp function used in softplus is not stable accross torch and numpy - assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-6, atol=0) + expected = test_soft_clip_func(im, 5, 95) + # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_soft_clipping_one_sided_high(self, p): @@ -71,10 +68,9 @@ def test_soft_clipping_one_sided_high(self, p): soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=None, sharpness_factor=1.0) im = p(self.imt) result = soft_clipper({key: im}) - upper = percentile(im, 95) - expected = soft_clip(im, sharpness_factor=1.0, minv=None, maxv=upper, dtype=torch.float32) - # the rtol is set to 5e-5 because the logaddexp function used in softplus is not stable accross torch and numpy - assert_allclose(result[key], p(expected), type_test="tensor", rtol=5e-5, atol=0) + expected = test_soft_clip_func(im, None, 95) + # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_soft_clipping_one_sided_low(self, p): @@ -82,10 +78,9 @@ def test_soft_clipping_one_sided_low(self, p): soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=None, lower=5, sharpness_factor=1.0) im = p(self.imt) result = soft_clipper({key: im}) - lower = percentile(im, 5) - expected = soft_clip(im, sharpness_factor=1.0, minv=lower, maxv=None, dtype=torch.float32) - # the rtol is set to 1e-6 because the logaddexp function used in softplus is not stable accross torch and numpy - assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-6, atol=0) + expected = test_soft_clip_func(im, 5, None) + # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_channel_wise(self, p): @@ -93,7 +88,8 @@ def test_channel_wise(self, p): clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, channel_wise=True) im = p(self.imt) result = clipper({key: im}) - for i, c in enumerate(im): + im_t = convert_to_tensor(self.imt) + for i, c in enumerate(im_t): lower, upper = percentile(c, (5, 95)) expected = clip(c, lower, upper) assert_allclose(result[key][i], p(expected), type_test="tensor", rtol=1e-3, atol=0) @@ -132,8 +128,7 @@ def test_hard_clipping_two_sided(self, p): hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5) im = p(self.imt) result = hard_clipper({key: im}) - lower, upper = percentile(im, (5, 95)) - expected = clip(im, lower, upper) + expected = test_hard_clip_func(im, 5, 95) assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) @@ -142,8 +137,7 @@ def test_hard_clipping_one_sided_high(self, p): hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=None) im = p(self.imt) result = hard_clipper({key: im}) - lower, upper = percentile(im, (0, 95)) - expected = clip(im, lower, upper) + expected = test_hard_clip_func(im, 0, 95) assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) @@ -152,8 +146,7 @@ def test_hard_clipping_one_sided_low(self, p): hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=None, lower=5) im = p(self.imt) result = hard_clipper({key: im}) - lower, upper = percentile(im, (5, 100)) - expected = clip(im, lower, upper) + expected = test_hard_clip_func(im, 5, 100) assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) @@ -162,10 +155,9 @@ def test_soft_clipping_two_sided(self, p): soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, sharpness_factor=1.0) im = p(self.imt) result = soft_clipper({key: im}) - lower, upper = percentile(im, (5, 95)) - expected = soft_clip(im, sharpness_factor=1.0, minv=lower, maxv=upper, dtype=torch.float32) - # the rtol is set to 1e-6 because the logaddexp function used in softplus is not stable accross torch and numpy - assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-6, atol=0) + expected = test_soft_clip_func(im, 5, 95) + # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_soft_clipping_one_sided_high(self, p): @@ -173,10 +165,9 @@ def test_soft_clipping_one_sided_high(self, p): soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=None, sharpness_factor=1.0) im = p(self.imt) result = soft_clipper({key: im}) - upper = percentile(im, 95) - expected = soft_clip(im, sharpness_factor=1.0, minv=None, maxv=upper, dtype=torch.float32) - # the rtol is set to 5e-5 because the logaddexp function used in softplus is not stable accross torch and numpy - assert_allclose(result[key], p(expected), type_test="tensor", rtol=5e-5, atol=0) + expected = test_soft_clip_func(im, None, 95) + # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy + assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_soft_clipping_one_sided_low(self, p): @@ -184,10 +175,9 @@ def test_soft_clipping_one_sided_low(self, p): soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=None, lower=5, sharpness_factor=1.0) im = p(self.imt) result = soft_clipper({key: im}) - lower = percentile(im, 5) - expected = soft_clip(im, sharpness_factor=1.0, minv=lower, maxv=None, dtype=torch.float32) + expected = test_soft_clip_func(im, 5, None) # the rtol is set to 1e-6 because the logaddexp function used in softplus is not stable accross torch and numpy - assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-6, atol=0) + assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_channel_wise(self, p): @@ -195,7 +185,8 @@ def test_channel_wise(self, p): clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, channel_wise=True) im = p(self.imt) result = clipper({key: im}) - for i, c in enumerate(im): + im_t = convert_to_tensor(im) + for i, c in enumerate(im_t): lower, upper = percentile(c, (5, 95)) expected = clip(c, lower, upper) assert_allclose(result[key][i], p(expected), type_test="tensor", rtol=1e-4, atol=0) diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py index 96e707acb5..e3b0aeb5a2 100644 --- a/tests/test_controlnet_inferers.py +++ b/tests/test_controlnet_inferers.py @@ -663,6 +663,8 @@ def test_prediction_shape( input_shape, latent_shape, ): + stage_1 = None + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": @@ -730,6 +732,8 @@ def test_sample_shape( input_shape, latent_shape, ): + stage_1 = None + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": @@ -793,6 +797,8 @@ def test_sample_intermediates( input_shape, latent_shape, ): + stage_1 = None + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": @@ -835,6 +841,10 @@ def test_sample_intermediates( controlnet=controlnet, cn_cond=mask, ) + + # TODO: this isn't correct, should the above produce intermediates as well? + # This test has always passed so is this branch not being used? + intermediates = None else: sample, intermediates = inferer.sample( input_noise=noise, @@ -846,6 +856,7 @@ def test_sample_intermediates( controlnet=controlnet, cn_cond=mask, ) + self.assertEqual(len(intermediates), 10) self.assertEqual(intermediates[0].shape, input_shape) @@ -861,6 +872,8 @@ def test_get_likelihoods( input_shape, latent_shape, ): + stage_1 = None + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": @@ -929,6 +942,8 @@ def test_resample_likelihoods( input_shape, latent_shape, ): + stage_1 = None + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": @@ -999,6 +1014,8 @@ def test_prediction_shape_conditioned_concat( input_shape, latent_shape, ): + stage_1 = None + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": @@ -1080,6 +1097,8 @@ def test_sample_shape_conditioned_concat( input_shape, latent_shape, ): + stage_1 = None + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": @@ -1156,6 +1175,8 @@ def test_sample_shape_different_latents( input_shape, latent_shape, ): + stage_1 = None + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 1398009c63..0d37ae2efd 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -23,7 +23,7 @@ from parameterized import parameterized from monai.data import Dataset -from monai.transforms import Compose, LoadImaged, SimulateDelayd +from monai.transforms import Compose, Lambda, LoadImage, LoadImaged, SimulateDelay, SimulateDelayd from tests.test_compose import TEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES, data_from_keys TEST_CASE_1 = [(128, 128, 128)] @@ -99,6 +99,72 @@ def test_dataset_lazy_on_call(self): data[0, 0:2, 0:2] = 1 +class TestTupleDataset(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1]) + def test_shape(self, expected_shape): + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) + with tempfile.TemporaryDirectory() as tempdir: + nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) + test_data = [ + (os.path.join(tempdir, "test_image1.nii.gz"), os.path.join(tempdir, "test_label1.nii.gz")), + (os.path.join(tempdir, "test_image2.nii.gz"), os.path.join(tempdir, "test_label2.nii.gz")), + ] + + test_transform = Compose([LoadImage(), SimulateDelay(delay_time=1e-5)]) + + # Here test_transform is applied element by element for the tuple. + dataset = Dataset(data=test_data, transform=test_transform) + data1 = dataset[0] + data2 = dataset[1] + + # Output is a list/tuple + self.assertTrue(isinstance(data1, (list, tuple))) + self.assertTrue(isinstance(data2, (list, tuple))) + + # Number of elements are 2 + self.assertEqual(len(data1), 2) + self.assertEqual(len(data2), 2) + + # Output shapes are as expected + self.assertTupleEqual(data1[0].shape, expected_shape) + self.assertTupleEqual(data1[1].shape, expected_shape) + self.assertTupleEqual(data2[0].shape, expected_shape) + self.assertTupleEqual(data2[1].shape, expected_shape) + + # Here test_transform is applied to the tuple as a whole. + test_transform = Compose( + [ + # LoadImage creates a channel-stacked image when applied to a tuple + LoadImage(), + # Get the channel-stacked image and the label + Lambda(func=lambda x: (x[0].permute(2, 1, 0), x[1])), + ], + map_items=False, + ) + + dataset = Dataset(data=test_data, transform=test_transform) + data1 = dataset[0] + data2 = dataset[1] + + # Output is a list/tuple + self.assertTrue(isinstance(data1, (list, tuple))) + self.assertTrue(isinstance(data2, (list, tuple))) + + # Number of elements are 2 + self.assertEqual(len(data1), 2) + self.assertEqual(len(data2), 2) + + # Output shapes are as expected + self.assertTupleEqual(data1[0].shape, expected_shape) + self.assertTupleEqual(data1[1].shape, expected_shape) + self.assertTupleEqual(data2[0].shape, expected_shape) + self.assertTupleEqual(data2[1].shape, expected_shape) + + class TestDatsesetWithLazy(unittest.TestCase): LOGGER_NAME = "a_logger_name" diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py index 0c9ad5869e..fe046a4cdf 100644 --- a/tests/test_ensure_channel_first.py +++ b/tests/test_ensure_channel_first.py @@ -50,9 +50,10 @@ class TestEnsureChannelFirst(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) @unittest.skipUnless(has_itk, "itk not installed") def test_load_nifti(self, input_param, filenames, original_channel_dim): - if original_channel_dim is None: - test_image = np.random.rand(8, 8, 8) - elif original_channel_dim == -1: + # if original_channel_dim is None + test_image = np.random.rand(8, 8, 8) + + if original_channel_dim == -1: test_image = np.random.rand(8, 8, 8, 1) with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_ensure_channel_firstd.py b/tests/test_ensure_channel_firstd.py index 63a437894b..e9effad951 100644 --- a/tests/test_ensure_channel_firstd.py +++ b/tests/test_ensure_channel_firstd.py @@ -35,9 +35,10 @@ class TestEnsureChannelFirstd(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_load_nifti(self, input_param, filenames, original_channel_dim): - if original_channel_dim is None: - test_image = np.random.rand(8, 8, 8) - elif original_channel_dim == -1: + # if original_channel_dim is None: + test_image = np.random.rand(8, 8, 8) + + if original_channel_dim == -1: test_image = np.random.rand(8, 8, 8, 1) with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_evenly_divisible_all_gather_dist.py b/tests/test_evenly_divisible_all_gather_dist.py index d6d26c7e23..f1d45ba48f 100644 --- a/tests/test_evenly_divisible_all_gather_dist.py +++ b/tests/test_evenly_divisible_all_gather_dist.py @@ -27,10 +27,10 @@ def test_data(self): self._run() def _run(self): - if dist.get_rank() == 0: - data1 = torch.tensor([[1, 2], [3, 4]]) - data2 = torch.tensor([[1.0, 2.0]]) - data3 = torch.tensor(7) + # if dist.get_rank() == 0 + data1 = torch.tensor([[1, 2], [3, 4]]) + data2 = torch.tensor([[1.0, 2.0]]) + data3 = torch.tensor(7) if dist.get_rank() == 1: data1 = torch.tensor([[5, 6]]) diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py index 46c9ad27d7..2e12b08aa9 100644 --- a/tests/test_handler_metrics_saver_dist.py +++ b/tests/test_handler_metrics_saver_dist.py @@ -51,8 +51,10 @@ def _val_func(engine, batch): engine = Engine(_val_func) + # define here to ensure symbol always exists regardless of the following if conditions + data = [{PostFix.meta("image"): {"filename_or_obj": [fnames[0]]}}] + if my_rank == 0: - data = [{PostFix.meta("image"): {"filename_or_obj": [fnames[0]]}}] @engine.on(Events.EPOCH_COMPLETED) def _save_metrics0(engine): diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py index 879a74969d..b91ba3f6b7 100644 --- a/tests/test_hilbert_transform.py +++ b/tests/test_hilbert_transform.py @@ -19,11 +19,11 @@ from monai.networks.layers import HilbertTransform from monai.utils import OptionalImportError -from tests.utils import SkipIfModule, SkipIfNoModule, skip_if_no_cuda +from tests.utils import SkipIfModule, SkipIfNoModule def create_expected_numpy_output(input_datum, **kwargs): - x = np.fft.fft(input_datum.cpu().numpy() if input_datum.device.type == "cuda" else input_datum.numpy(), **kwargs) + x = np.fft.fft(input_datum.cpu().numpy(), **kwargs) f = np.fft.fftfreq(x.shape[kwargs["axis"]]) u = np.heaviside(f, 0.5) new_dims_before = kwargs["axis"] @@ -44,19 +44,15 @@ def create_expected_numpy_output(input_datum, **kwargs): # CPU TEST DATA cpu_input_data = {} -cpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=cpu).unsqueeze(0).unsqueeze(0) -cpu_input_data["2D"] = ( - torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu).unsqueeze(0).unsqueeze(0) -) -cpu_input_data["3D"] = ( - torch.as_tensor(np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu) - .unsqueeze(0) - .unsqueeze(0) -) -cpu_input_data["1D 2CH"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu).unsqueeze(0) +cpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=cpu)[None, None] +cpu_input_data["2D"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu)[None, None] +cpu_input_data["3D"] = torch.as_tensor( + np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu +)[None, None] +cpu_input_data["1D 2CH"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu)[None] cpu_input_data["2D 2CH"] = torch.as_tensor( np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu -).unsqueeze(0) +)[None] # SINGLE-CHANNEL CPU VALUE TESTS @@ -97,64 +93,21 @@ def create_expected_numpy_output(input_datum, **kwargs): 1e-5, # absolute tolerance ] +TEST_CASES_CPU = [ + TEST_CASE_1D_SINE_CPU, + TEST_CASE_2D_SINE_CPU, + TEST_CASE_3D_SINE_CPU, + TEST_CASE_1D_2CH_SINE_CPU, + TEST_CASE_2D_2CH_SINE_CPU, +] + # GPU TEST DATA if torch.cuda.is_available(): gpu = torch.device("cuda") - - gpu_input_data = {} - gpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=gpu).unsqueeze(0).unsqueeze(0) - gpu_input_data["2D"] = ( - torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=gpu).unsqueeze(0).unsqueeze(0) - ) - gpu_input_data["3D"] = ( - torch.as_tensor(np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=gpu) - .unsqueeze(0) - .unsqueeze(0) - ) - gpu_input_data["1D 2CH"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=gpu).unsqueeze(0) - gpu_input_data["2D 2CH"] = torch.as_tensor( - np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=gpu - ).unsqueeze(0) - - # SINGLE CHANNEL GPU VALUE TESTS - - TEST_CASE_1D_SINE_GPU = [ - {}, # args (empty, so use default) - gpu_input_data["1D"], # Input data: Random 1D signal - create_expected_numpy_output(gpu_input_data["1D"], axis=2), # Expected output: FFT of signal - 1e-5, # absolute tolerance - ] - - TEST_CASE_2D_SINE_GPU = [ - {}, # args (empty, so use default) - gpu_input_data["2D"], # Input data: Random 1D signal - create_expected_numpy_output(gpu_input_data["2D"], axis=2), # Expected output: FFT of signal - 1e-5, # absolute tolerance - ] - - TEST_CASE_3D_SINE_GPU = [ - {}, # args (empty, so use default) - gpu_input_data["3D"], # Input data: Random 1D signal - create_expected_numpy_output(gpu_input_data["3D"], axis=2), # Expected output: FFT of signal - 1e-5, # absolute tolerance - ] - - # MULTICHANNEL GPU VALUE TESTS, PROCESS ALONG FIRST SPATIAL AXIS - - TEST_CASE_1D_2CH_SINE_GPU = [ - {}, # args (empty, so use default) - gpu_input_data["1D 2CH"], # Input data: Random 1D signal - create_expected_numpy_output(gpu_input_data["1D 2CH"], axis=2), - 1e-5, # absolute tolerance - ] - - TEST_CASE_2D_2CH_SINE_GPU = [ - {}, # args (empty, so use default) - gpu_input_data["2D 2CH"], # Input data: Random 1D signal - create_expected_numpy_output(gpu_input_data["2D 2CH"], axis=2), - 1e-5, # absolute tolerance - ] + TEST_CASES_GPU = [[args, image.to(gpu), exp_data, atol] for args, image, exp_data, atol in TEST_CASES_CPU] +else: + TEST_CASES_GPU = [] # TESTS CHECKING PADDING, AXIS SELECTION ETC ARE COVERED BY test_detect_envelope.py @@ -162,42 +115,10 @@ def create_expected_numpy_output(input_datum, **kwargs): @SkipIfNoModule("torch.fft") class TestHilbertTransformCPU(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_1D_SINE_CPU, - TEST_CASE_2D_SINE_CPU, - TEST_CASE_3D_SINE_CPU, - TEST_CASE_1D_2CH_SINE_CPU, - TEST_CASE_2D_2CH_SINE_CPU, - ] - ) - def test_value(self, arguments, image, expected_data, atol): - result = HilbertTransform(**arguments)(image) - result = result.squeeze(0).squeeze(0).numpy() - np.testing.assert_allclose(result, expected_data.squeeze(), atol=atol) - - -@skip_if_no_cuda -@SkipIfNoModule("torch.fft") -class TestHilbertTransformGPU(unittest.TestCase): - - @parameterized.expand( - ( - [] - if not torch.cuda.is_available() - else [ - TEST_CASE_1D_SINE_GPU, - TEST_CASE_2D_SINE_GPU, - TEST_CASE_3D_SINE_GPU, - TEST_CASE_1D_2CH_SINE_GPU, - TEST_CASE_2D_2CH_SINE_GPU, - ] - ), - skip_on_empty=True, - ) + @parameterized.expand(TEST_CASES_CPU + TEST_CASES_GPU) def test_value(self, arguments, image, expected_data, atol): result = HilbertTransform(**arguments)(image) - result = result.squeeze(0).squeeze(0).cpu().numpy() + result = np.squeeze(result.cpu().numpy()) np.testing.assert_allclose(result, expected_data.squeeze(), atol=atol) diff --git a/tests/test_integration_unet_2d.py b/tests/test_integration_unet_2d.py index 918190775c..3b40682de0 100644 --- a/tests/test_integration_unet_2d.py +++ b/tests/test_integration_unet_2d.py @@ -35,6 +35,7 @@ def __getitem__(self, _unused_id): def __len__(self): return train_steps + net = None if net_name == "basicunet": net = BasicUNet(spatial_dims=2, in_channels=1, out_channels=1, features=(4, 8, 8, 16, 16, 32)) elif net_name == "unet": diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index 065ebafd95..2e04ad6c5c 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -320,6 +320,8 @@ class TestDiffusionSamplingInferer(unittest.TestCase): def test_prediction_shape( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): + stage_1 = None + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": @@ -368,6 +370,8 @@ def test_prediction_shape( def test_sample_shape( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): + stage_1 = None + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": @@ -413,6 +417,8 @@ def test_sample_shape( def test_sample_intermediates( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): + stage_1 = None + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": @@ -468,6 +474,8 @@ def test_sample_intermediates( def test_get_likelihoods( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): + stage_1 = None + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": @@ -521,6 +529,8 @@ def test_get_likelihoods( def test_resample_likelihoods( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): + stage_1 = None + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": @@ -576,6 +586,8 @@ def test_resample_likelihoods( def test_prediction_shape_conditioned_concat( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): + stage_1 = None + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": @@ -642,6 +654,8 @@ def test_prediction_shape_conditioned_concat( def test_sample_shape_conditioned_concat( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): + stage_1 = None + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": @@ -703,6 +717,8 @@ def test_sample_shape_conditioned_concat( def test_sample_shape_different_latents( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): + stage_1 = None + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": diff --git a/tests/test_map_and_generate_sampling_centers.py b/tests/test_map_and_generate_sampling_centers.py new file mode 100644 index 0000000000..ff74f974b9 --- /dev/null +++ b/tests/test_map_and_generate_sampling_centers.py @@ -0,0 +1,87 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from copy import deepcopy + +import numpy as np +from parameterized import parameterized + +from monai.transforms import map_and_generate_sampling_centers +from monai.utils.misc import set_determinism +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_CASE_1 = [ + # test Argmax data + { + "label": (np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "spatial_size": [2, 2, 2], + "num_samples": 2, + "label_spatial_shape": [3, 3, 3], + "num_classes": 3, + "image": None, + "ratios": [0, 1, 2], + "image_threshold": 0.0, + }, + tuple, + 2, + 3, +] + +TEST_CASE_2 = [ + { + "label": ( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + ), + "spatial_size": [2, 2, 2], + "num_samples": 1, + "ratios": None, + "label_spatial_shape": [3, 3, 3], + "image": None, + "image_threshold": 0.0, + }, + tuple, + 1, + 3, +] + + +class TestMapAndGenerateSamplingCenters(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_map_and_generate_sampling_centers(self, input_data, expected_type, expected_count, expected_shape): + results = [] + for p in TEST_NDARRAYS + (None,): + input_data = deepcopy(input_data) + if p is not None: + input_data["label"] = p(input_data["label"]) + set_determinism(0) + result = map_and_generate_sampling_centers(**input_data) + self.assertIsInstance(result, expected_type) + self.assertEqual(len(result), expected_count) + self.assertEqual(len(result[0]), expected_shape) + # check for consistency between numpy, torch and torch.cuda + results.append(result) + if len(results) > 1: + for x, y in zip(result[0], result[-1]): + assert_allclose(x, y, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index 17f49611df..9d5012c9a3 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -89,7 +89,7 @@ def tearDown(self) -> None: @parameterized.expand(TESTS) def test_pad_collation(self, t_type, collate_method, transform): - if t_type == dict: + if t_type is dict: dataset = CacheDataset(self.dict_data, transform, progress=False) else: dataset = _Dataset(self.list_data, self.list_labels, transform) @@ -104,7 +104,7 @@ def test_pad_collation(self, t_type, collate_method, transform): loader = DataLoader(dataset, batch_size=10, collate_fn=collate_method) # check collation in forward direction for data in loader: - if t_type == dict: + if t_type is dict: shapes = [] decollated_data = decollate_batch(data) for d in decollated_data: @@ -113,7 +113,7 @@ def test_pad_collation(self, t_type, collate_method, transform): self.assertTrue(len(output["image"].applied_operations), len(dataset.transform.transforms)) self.assertTrue(len(set(shapes)) > 1) # inverted shapes must be different because of random xforms - if t_type == dict: + if t_type is dict: batch_inverse = BatchInverseTransform(dataset.transform, loader) for data in loader: output = batch_inverse(data) diff --git a/tests/test_profiling.py b/tests/test_profiling.py index 6bee7ba262..649d980ebf 100644 --- a/tests/test_profiling.py +++ b/tests/test_profiling.py @@ -35,6 +35,7 @@ def setUp(self): self.scale = mt.ScaleIntensity() self.scale_call_name = "ScaleIntensity.__call__" + self.compose_call_name = "Compose.__call__" self.test_comp = mt.Compose([mt.ScaleIntensity(), mt.RandAxisFlip(0.5)]) self.test_image = torch.rand(1, 16, 16, 16) self.pid = os.getpid() @@ -82,7 +83,7 @@ def test_profile_multithread(self): self.assertSequenceEqual(batch.shape, (4, 1, 16, 16, 16)) results = wp.get_results() - self.assertSequenceEqual(list(results), [self.scale_call_name]) + self.assertSequenceEqual(list(results), [self.scale_call_name, self.compose_call_name]) prs = results[self.scale_call_name] @@ -98,6 +99,7 @@ def test_profile_context(self): self.scale(self.test_image) results = wp.get_results() + self.assertSequenceEqual(set(results), {"ScaleIntensity.__call__", "context"}) prs = results["context"] diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index 1fb81689e6..8afc2da6ad 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -83,6 +83,9 @@ def forward(self, x): # initialize a SGD optimizer optimizer = optim.Adam(net.parameters(), lr=learning_rate) + # declare first for pylint + init_loss = None + # train the network for it in range(max_iter): # set the gradient to zero diff --git a/tests/test_regularization.py b/tests/test_regularization.py index 4df60b9808..12d64637d5 100644 --- a/tests/test_regularization.py +++ b/tests/test_regularization.py @@ -13,28 +13,31 @@ import unittest +import numpy as np import torch -from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd -from monai.utils import set_determinism +from monai.transforms import CutMix, CutMixd, CutOut, CutOutd, MixUp, MixUpd +from tests.utils import assert_allclose class TestMixup(unittest.TestCase): - def setUp(self) -> None: - set_determinism(seed=0) - - def tearDown(self) -> None: - set_determinism(None) - def test_mixup(self): for dims in [2, 3]: shape = (6, 3) + (32,) * dims sample = torch.rand(*shape, dtype=torch.float32) mixup = MixUp(6, 1.0) + mixup.set_random_state(seed=0) output = mixup(sample) + np.random.seed(0) + # simulate the randomize() of transform + np.random.random() + weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32) + perm = np.random.permutation(6) self.assertEqual(output.shape, sample.shape) - self.assertTrue(any(not torch.allclose(sample, mixup(sample)) for _ in range(10))) + mixweight = weight[(Ellipsis,) + (None,) * (dims + 1)] + expected = mixweight * sample + (1 - mixweight) * sample[perm, ...] + assert_allclose(output, expected, type_test=False, atol=1e-7) with self.assertRaises(ValueError): MixUp(6, -0.5) @@ -52,8 +55,19 @@ def test_mixupd(self): t = torch.rand(*shape, dtype=torch.float32) sample = {"a": t, "b": t} mixup = MixUpd(["a", "b"], 6) + mixup.set_random_state(seed=0) output = mixup(sample) - self.assertTrue(torch.allclose(output["a"], output["b"])) + np.random.seed(0) + # simulate the randomize() of transform + np.random.random() + weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32) + perm = np.random.permutation(6) + self.assertEqual(output["a"].shape, sample["a"].shape) + mixweight = weight[(Ellipsis,) + (None,) * (dims + 1)] + expected = mixweight * sample["a"] + (1 - mixweight) * sample["a"][perm, ...] + assert_allclose(output["a"], expected, type_test=False, atol=1e-7) + assert_allclose(output["a"], output["b"], type_test=False, atol=1e-7) + # self.assertTrue(torch.allclose(output["a"], output["b"])) with self.assertRaises(ValueError): MixUpd(["k1", "k2"], 6, -0.5) @@ -61,17 +75,12 @@ def test_mixupd(self): class TestCutMix(unittest.TestCase): - def setUp(self) -> None: - set_determinism(seed=0) - - def tearDown(self) -> None: - set_determinism(None) - def test_cutmix(self): for dims in [2, 3]: shape = (6, 3) + (32,) * dims sample = torch.rand(*shape, dtype=torch.float32) cutmix = CutMix(6, 1.0) + cutmix.set_random_state(seed=0) output = cutmix(sample) self.assertEqual(output.shape, sample.shape) self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10))) @@ -83,29 +92,50 @@ def test_cutmixd(self): label = torch.randint(0, 1, shape) sample = {"a": t, "b": t, "lbl1": label, "lbl2": label} cutmix = CutMixd(["a", "b"], 6, label_keys=("lbl1", "lbl2")) + cutmix.set_random_state(seed=123) output = cutmix(sample) - # croppings are different on each application - self.assertTrue(not torch.allclose(output["a"], output["b"])) # but mixing of labels is not affected by it self.assertTrue(torch.allclose(output["lbl1"], output["lbl2"])) class TestCutOut(unittest.TestCase): - def setUp(self) -> None: - set_determinism(seed=0) - - def tearDown(self) -> None: - set_determinism(None) - def test_cutout(self): for dims in [2, 3]: shape = (6, 3) + (32,) * dims sample = torch.rand(*shape, dtype=torch.float32) cutout = CutOut(6, 1.0) + cutout.set_random_state(seed=123) output = cutout(sample) + np.random.seed(123) + # simulate the randomize() of transform + np.random.random() + weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32) + perm = np.random.permutation(6) + coords = [torch.from_numpy(np.random.randint(0, d, size=(1,))) for d in sample.shape[2:]] + assert_allclose(weight, cutout._params[0]) + assert_allclose(perm, cutout._params[1]) + self.assertSequenceEqual(coords, cutout._params[2]) self.assertEqual(output.shape, sample.shape) - self.assertTrue(any(not torch.allclose(sample, cutout(sample)) for _ in range(10))) + + def test_cutoutd(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + t = torch.rand(*shape, dtype=torch.float32) + sample = {"a": t, "b": t} + cutout = CutOutd(["a", "b"], 6, 1.0) + cutout.set_random_state(seed=123) + output = cutout(sample) + np.random.seed(123) + # simulate the randomize() of transform + np.random.random() + weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32) + perm = np.random.permutation(6) + coords = [torch.from_numpy(np.random.randint(0, d, size=(1,))) for d in t.shape[2:]] + assert_allclose(weight, cutout.cutout._params[0]) + assert_allclose(perm, cutout.cutout._params[1]) + self.assertSequenceEqual(coords, cutout.cutout._params[2]) + self.assertEqual(output["a"].shape, sample["a"].shape) if __name__ == "__main__": diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 5d34a32d8d..e873f1238a 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -107,6 +107,7 @@ "num_classes": 3, "conv1_t_size": [3], "conv1_t_stride": 1, + "act": ("relu", {"inplace": False}), }, (1, 2, 32), (1, 3), @@ -185,13 +186,46 @@ (1, 3), ] +TEST_CASE_8 = [ + { + "block": "bottleneck", + "layers": [3, 4, 6, 3], + "block_inplanes": [64, 128, 256, 512], + "spatial_dims": 1, + "n_input_channels": 2, + "num_classes": 3, + "conv1_t_size": [3], + "conv1_t_stride": 1, + "act": ("relu", {"inplace": False}), + }, + (1, 2, 32), + (1, 3), +] + +TEST_CASE_9 = [ # Layer norm + { + "block": ResNetBlock, + "layers": [3, 4, 6, 3], + "block_inplanes": [64, 128, 256, 512], + "spatial_dims": 1, + "n_input_channels": 2, + "num_classes": 3, + "conv1_t_size": [3], + "conv1_t_stride": 1, + "act": ("relu", {"inplace": False}), + "norm": ("layer", {"normalized_shape": (64, 32)}), + }, + (1, 2, 32), + (1, 3), +] + TEST_CASES = [] PRETRAINED_TEST_CASES = [] for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]: for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]: TEST_CASES.append([model, *case]) PRETRAINED_TEST_CASES.append([model, *case]) -for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7]: +for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]: TEST_CASES.append([ResNet, *case]) TEST_SCRIPT_CASES = [ @@ -207,15 +241,6 @@ ] -CASE_EXTRACT_FEATURES = [ - ( - {"model_name": "resnet10", "pretrained": True, "spatial_dims": 3, "in_channels": 1}, - [1, 1, 64, 64, 64], - ([1, 64, 32, 32, 32], [1, 64, 16, 16, 16], [1, 128, 8, 8, 8], [1, 256, 4, 4, 4], [1, 512, 2, 2, 2]), - ) -] - - class TestResNet(unittest.TestCase): def setUp(self): diff --git a/tests/test_synthetic.py b/tests/test_synthetic.py index 7db3c3e77a..4ab2144568 100644 --- a/tests/test_synthetic.py +++ b/tests/test_synthetic.py @@ -47,7 +47,7 @@ def test_create_test_image(self, dim, input_param, expected_img, expected_seg, e set_determinism(seed=0) if dim == 2: img, seg = create_test_image_2d(**input_param) - elif dim == 3: + else: # dim == 3 img, seg = create_test_image_3d(**input_param) self.assertEqual(img.shape, expected_shape) self.assertEqual(seg.max(), expected_max_cls) diff --git a/tests/test_vis_cam.py b/tests/test_vis_cam.py index b641599af2..68b12de2f8 100644 --- a/tests/test_vis_cam.py +++ b/tests/test_vis_cam.py @@ -70,6 +70,8 @@ class TestClassActivationMap(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, input_data, expected_shape): + model = None + if input_data["model"] == "densenet2d": model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) if input_data["model"] == "densenet3d": @@ -80,6 +82,7 @@ def test_shape(self, input_data, expected_shape): model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4) if input_data["model"] == "senet3d": model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4) + device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py index 325b74b3ce..f77d916a5b 100644 --- a/tests/test_vis_gradcam.py +++ b/tests/test_vis_gradcam.py @@ -153,6 +153,8 @@ class TestGradientClassActivationMap(unittest.TestCase): @parameterized.expand(TESTS) def test_shape(self, cam_class, input_data, expected_shape): + model = None + if input_data["model"] == "densenet2d": model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) elif input_data["model"] == "densenet2d_bin": diff --git a/tests/test_warp.py b/tests/test_warp.py index 55f40764c3..0e5f2466db 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -217,6 +217,7 @@ def itk_warp(img, ddf): # warp warp_filter.SetDisplacementField(displacement_field) warp_filter.SetInput(itk_img) + warp_filter.Update() warped_img = warp_filter.GetOutput() warped_img = np.asarray(warped_img) diff --git a/tests/utils.py b/tests/utils.py index ea73a3ed81..d1939e590b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -156,6 +156,7 @@ def skip_if_downloading_fails(): "limit", # HTTP Error 503: Egress is over the account limit "authenticate", "timed out", # urlopen error [Errno 110] Connection timed out + "HTTPError", # HTTPError: 429 Client Error: Too Many Requests for huggingface hub ) ): raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e # incomplete download From 95c22007a30b3f3ab7ef65bb59facec7133fd39a Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 2 Jul 2024 20:57:04 +0100 Subject: [PATCH 22/32] Resolving conflicts Signed-off-by: Eric Kerfoot --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b9debaf08f..3fff6ed631 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.6.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace From 97ebde89dfaf447978847967de998613dcd9b10f Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 2 Jul 2024 20:57:49 +0100 Subject: [PATCH 23/32] Resolving conflicts Signed-off-by: Eric Kerfoot --- CHANGELOG.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 38336505ed..804508c262 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [Unreleased] +## [1.3.2] - 2024-06-25 +### Fixed +#### misc. +* Updated Numpy version constraint to < 2.0 (#7859) + ## [1.3.1] - 2024-05-17 ### Added * Support for `by_measure` argument in `RemoveSmallObjects` (#7137) @@ -1035,7 +1040,8 @@ the postprocessing steps should be used before calling the metrics methods [highlights]: https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md -[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/1.3.1...HEAD +[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/1.3.2...HEAD +[1.3.2]: https://github.com/Project-MONAI/MONAI/compare/1.3.1...1.3.2 [1.3.1]: https://github.com/Project-MONAI/MONAI/compare/1.3.0...1.3.1 [1.3.0]: https://github.com/Project-MONAI/MONAI/compare/1.2.0...1.3.0 [1.2.0]: https://github.com/Project-MONAI/MONAI/compare/1.1.0...1.2.0 From 72a7fa02c47c6977500d45ae08d496ad1b724342 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 2 Jul 2024 21:24:24 +0100 Subject: [PATCH 24/32] Resolving conflicts Signed-off-by: Eric Kerfoot --- CITATION.cff | 4 +- docs/source/transforms.rst | 6 + monai/metrics/utils.py | 6 +- monai/networks/blocks/selfattention.py | 23 +- monai/transforms/regularization/array.py | 3 - monai/utils/module.py | 9 +- requirements-dev.txt | 3 +- tests/test_median_filter.py | 1 + ...est_ultrasound_confidence_map_transform.py | 907 +++++++++--------- 9 files changed, 489 insertions(+), 473 deletions(-) diff --git a/CITATION.cff b/CITATION.cff index 4754c5b2e3..b535a77a9f 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -6,8 +6,8 @@ title: "MONAI: Medical Open Network for AI" abstract: "AI Toolkit for Healthcare Imaging" authors: - name: "MONAI Consortium" -date-released: 2024-05-21 -version: "1.3.1" +date-released: 2024-06-26 +version: "1.3.2" identifiers: - description: "This DOI represents all versions of MONAI, and will always resolve to the latest one." type: doi diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 8bd5bfd99f..a359821679 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -667,6 +667,12 @@ Post-processing :members: :special-members: __call__ +`Invert` +""""""""" +.. autoclass:: Invert + :members: + :special-members: __call__ + Regularization ^^^^^^^^^^^^^^ diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index e7057256fb..340e54a1d7 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -35,9 +35,9 @@ optional_import, ) -binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion") -distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt") -distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") +binary_erosion, _ = optional_import("scipy.ndimage", name="binary_erosion") +distance_transform_edt, _ = optional_import("scipy.ndimage", name="distance_transform_edt") +distance_transform_cdt, _ = optional_import("scipy.ndimage", name="distance_transform_cdt") __all__ = [ "ignore_background", diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 370ad38595..6563133e07 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -38,11 +38,7 @@ def __init__( dim_head: int | None = None, qkv_bias: bool = False, save_attn: bool = False, - causal: bool = False, - sequence_length: int | None = None, - rel_pos_embedding: Optional[str] = None, - input_size: Optional[Tuple] = None, - attention_dtype: Optional[torch.dtype] = None, + dim_head: int | None = None, ) -> None: """ Args: @@ -53,14 +49,7 @@ def __init__( dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads. qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. - causal: whether to use causal attention (see https://arxiv.org/abs/1706.03762). - sequence_length: if causal is True, it is necessary to specify the sequence length. - rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. - For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. - input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative - positional parameter size. - attention_dtype: cast attention operations to this dtype. - + dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads. """ super().__init__() @@ -81,9 +70,11 @@ def __init__( raise ValueError("sequence_length is necessary for causal attention.") self.num_heads = num_heads - self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size - self.out_proj = nn.Linear(inner_dim, self.hidden_input_size) - self.qkv = nn.Linear(self.hidden_input_size, inner_dim * 3, bias=qkv_bias) + self.dim_head = hidden_size // num_heads if dim_head is None else dim_head + self.inner_dim = self.dim_head * num_heads + + self.out_proj = nn.Linear(self.inner_dim, hidden_size) + self.qkv = nn.Linear(hidden_size, self.inner_dim * 3, bias=qkv_bias) self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) self.out_rearrange = Rearrange("b h l d -> b l (h d)") self.drop_output = nn.Dropout(dropout_rate) diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py index 9186a5c46f..4bf6cff649 100644 --- a/monai/transforms/regularization/array.py +++ b/monai/transforms/regularization/array.py @@ -88,7 +88,6 @@ def apply(self, data: torch.Tensor): def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True): data_t = convert_to_tensor(data, track_meta=get_track_meta()) labels_t = data_t # will not stay this value, needed to satisfy pylint/mypy - if labels is not None: labels_t = convert_to_tensor(labels, track_meta=get_track_meta()) if randomize: @@ -153,7 +152,6 @@ def apply_on_labels(self, labels: torch.Tensor): def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True): data_t = convert_to_tensor(data, track_meta=get_track_meta()) augmented_label = None - if labels is not None: labels_t = convert_to_tensor(labels, track_meta=get_track_meta()) if randomize: @@ -162,7 +160,6 @@ def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, rando if labels is not None: augmented_label = convert_to_dst_type(self.apply(labels_t), dst=labels)[0] - return (augmented, augmented_label) if labels is not None else augmented diff --git a/monai/utils/module.py b/monai/utils/module.py index 6f301d8067..4d28f8d986 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -13,6 +13,7 @@ import enum import functools +import importlib.util import os import pdb import re @@ -208,9 +209,11 @@ def load_submodules( ): if (is_pkg or load_all) and name not in sys.modules and match(exclude_pattern, name) is None: try: - mod = import_module(name) - importer.find_spec(name).loader.load_module(name) # type: ignore - submodules.append(mod) + mod_spec = importer.find_spec(name) # type: ignore + if mod_spec and mod_spec.loader: + mod = importlib.util.module_from_spec(mod_spec) + mod_spec.loader.exec_module(mod) + submodules.append(mod) except OptionalImportError: pass # could not import the optional deps., they are ignored except ImportError as e: diff --git a/requirements-dev.txt b/requirements-dev.txt index c50d9248df..1bba930273 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -57,4 +57,5 @@ zarr lpips==0.1.4 nvidia-ml-py huggingface_hub -opencv-python-headless +pyamg>=5.0.0 +git+https://github.com/Project-MONAI/GenerativeModels.git@7428fce193771e9564f29b91d29e523dd1b6b4cd diff --git a/tests/test_median_filter.py b/tests/test_median_filter.py index bdfdf24f9f..02fa812380 100644 --- a/tests/test_median_filter.py +++ b/tests/test_median_filter.py @@ -28,6 +28,7 @@ def test_3d(self, input_tensor, radius): expected = input_tensor.numpy() output = filter(input_tensor).cpu().numpy() + np.testing.assert_allclose(output, expected, rtol=1e-5) def test_3d_radii(self): diff --git a/tests/test_ultrasound_confidence_map_transform.py b/tests/test_ultrasound_confidence_map_transform.py index 63ce7d58e4..87c08b3ac3 100644 --- a/tests/test_ultrasound_confidence_map_transform.py +++ b/tests/test_ultrasound_confidence_map_transform.py @@ -11,11 +11,13 @@ from __future__ import annotations +import os import unittest import numpy as np import torch from parameterized import parameterized +from PIL import Image from monai.transforms import UltrasoundConfidenceMapTransform from tests.utils import assert_allclose @@ -32,7 +34,8 @@ [1, 2, 3, 32, 33, 34, 35, 1, 2, 3], [1, 2, 3, 36, 37, 38, 39, 1, 2, 3], [1, 2, 3, 40, 41, 42, 43, 1, 2, 3], - ] + ], + dtype=np.float32, ) TEST_MASK = np.array( @@ -47,474 +50,435 @@ [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] + ], + dtype=np.float32, ) SINK_ALL_OUTPUT = np.array( [ [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [ - 0.97514489, - 0.96762971, - 0.96164186, - 0.95463443, - 0.9941512, - 0.99023054, - 0.98559401, - 0.98230057, - 0.96601224, - 0.95119599, - ], - [ - 0.92960533, - 0.92638451, - 0.9056675, - 0.9487176, - 0.9546961, - 0.96165853, - 0.96172303, - 0.92686401, - 0.92122613, - 0.89957239, - ], - [ - 0.86490963, - 0.85723665, - 0.83798141, - 0.90816201, - 0.90816097, - 0.90815301, - 0.9081427, - 0.85933627, - 0.85146935, - 0.82948586, - ], - [ - 0.77430346, - 0.76731372, - 0.74372311, - 0.89128774, - 0.89126885, - 0.89125066, - 0.89123521, - 0.76858589, - 0.76106647, - 0.73807776, - ], - [ - 0.66098109, - 0.65327697, - 0.63090644, - 0.33086588, - 0.3308383, - 0.33081937, - 0.33080718, - 0.6557468, - 0.64825099, - 0.62593375, - ], - [ - 0.52526945, - 0.51832586, - 0.49709412, - 0.25985059, - 0.25981009, - 0.25977729, - 0.25975222, - 0.52118958, - 0.51426328, - 0.49323164, - ], - [ - 0.3697845, - 0.36318971, - 0.34424661, - 0.17386804, - 0.17382046, - 0.17377993, - 0.17374668, - 0.36689317, - 0.36036096, - 0.3415582, - ], - [ - 0.19546374, - 0.1909659, - 0.17319999, - 0.08423318, - 0.08417993, - 0.08413242, - 0.08409104, - 0.19393909, - 0.18947485, - 0.17185031, + 0.8884930952884654, + 0.8626656901726876, + 0.8301161870669913, + 0.9757179300830185, + 0.9989819637626414, + 0.9994717624885747, + 0.9954377526794013, + 0.8898638133944221, + 0.862604343021387, + 0.8277862494812598, + ], + [ + 0.7765718877433174, + 0.7363731552518268, + 0.6871875923653379, + 0.9753673327387775, + 0.9893175316399789, + 0.9944181334242039, + 0.9936979128319371, + 0.7778001700035326, + 0.7362622619974832, + 0.6848377775329241, + ], + [ + 0.6648416226360719, + 0.6178079903692397, + 0.5630152545966568, + 0.8278402502498404, + 0.82790391019578, + 0.8289702087149963, + 0.8286730258710652, + 0.6658773633169731, + 0.6176836507071695, + 0.5609165245633834, + ], + [ + 0.5534420483956817, + 0.5055401989946189, + 0.451865872383879, + 0.7541423053657541, + 0.7544115886347456, + 0.7536884376055174, + 0.7524927915364896, + 0.5542943466824017, + 0.505422678400297, + 0.4502051549732117, + ], + [ + 0.4423657561928356, + 0.398221575954319, + 0.35030055029978124, + 0.4793202144786371, + 0.48057175662074125, + 0.4812057229564038, + 0.48111949176149327, + 0.44304092606050766, + 0.39812149713417405, + 0.34902458531143377, + ], + [ + 0.3315561576450342, + 0.29476346732036784, + 0.2558303772864961, + 0.35090405668257535, + 0.3515225984307705, + 0.35176548159366317, + 0.3516979775419521, + 0.33205839061494885, + 0.2946859567272435, + 0.2549042599220772, + ], + [ + 0.22094175240967673, + 0.19431840633358133, + 0.16672448058324435, + 0.22716195845848167, + 0.22761996456848282, + 0.22782525614780919, + 0.22781876632199002, + 0.22127471252104777, + 0.19426593309729956, + 0.16612306610996525, + ], + [ + 0.11044782531624744, + 0.09623229814933323, + 0.08174664901235043, + 0.11081911718888311, + 0.11102310514207447, + 0.1111041051969924, + 0.11108329076967229, + 0.11061376973431204, + 0.09620592927336903, + 0.08145227209865454, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ] + ], + dtype=np.float32, ) SINK_MID_OUTPUT = np.array( [ + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [ - 1.00000000e00, - 1.00000000e00, - 1.00000000e00, - 1.00000000e00, - 1.00000000e00, - 1.00000000e00, - 1.00000000e00, - 1.00000000e00, - 1.00000000e00, - 1.00000000e00, - ], - [ - 9.99996103e-01, - 9.99994823e-01, - 9.99993550e-01, - 9.99930863e-01, - 9.99990782e-01, - 9.99984683e-01, - 9.99979000e-01, - 9.99997804e-01, - 9.99995985e-01, - 9.99994325e-01, - ], - [ - 9.99989344e-01, - 9.99988600e-01, - 9.99984099e-01, - 9.99930123e-01, - 9.99926598e-01, - 9.99824297e-01, - 9.99815032e-01, - 9.99991228e-01, - 9.99990881e-01, - 9.99988462e-01, - ], - [ - 9.99980787e-01, - 9.99979264e-01, - 9.99975828e-01, - 9.59669286e-01, - 9.59664779e-01, - 9.59656566e-01, - 9.59648332e-01, - 9.99983882e-01, - 9.99983038e-01, - 9.99980732e-01, - ], - [ - 9.99970181e-01, - 9.99969032e-01, - 9.99965730e-01, - 9.45197806e-01, - 9.45179593e-01, - 9.45163629e-01, - 9.45151458e-01, - 9.99973352e-01, - 9.99973254e-01, - 9.99971098e-01, - ], - [ - 9.99958608e-01, - 9.99957307e-01, - 9.99953444e-01, - 4.24743523e-01, - 4.24713305e-01, - 4.24694646e-01, - 4.24685271e-01, - 9.99960948e-01, - 9.99961829e-01, - 9.99960347e-01, - ], - [ - 9.99946675e-01, - 9.99945139e-01, - 9.99940312e-01, - 3.51353224e-01, - 3.51304003e-01, - 3.51268260e-01, - 3.51245366e-01, - 9.99947688e-01, - 9.99950165e-01, - 9.99949512e-01, - ], - [ - 9.99935877e-01, - 9.99934088e-01, - 9.99928982e-01, - 2.51197134e-01, - 2.51130273e-01, - 2.51080014e-01, - 2.51045852e-01, - 9.99936187e-01, - 9.99939716e-01, - 9.99940022e-01, - ], - [ - 9.99927846e-01, - 9.99925911e-01, - 9.99920188e-01, - 1.31550973e-01, - 1.31462736e-01, - 1.31394558e-01, - 1.31346069e-01, - 9.99927275e-01, - 9.99932142e-01, - 9.99933313e-01, - ], - [ - 9.99924204e-01, - 9.99922004e-01, - 9.99915767e-01, - 3.04861147e-04, - 1.95998056e-04, - 0.00000000e00, - 2.05182682e-05, - 9.99923115e-01, - 9.99928835e-01, - 9.99930535e-01, - ], - ] + 0.9999957448889315, + 0.9999781044114231, + 0.9999142422442185, + 0.999853253199584, + 0.9999918403054282, + 0.9999874855193227, + 0.9999513619364747, + 0.9999589247003497, + 0.9999861765528631, + 0.9999939213967494, + ], + [ + 0.9999918011366045, + 0.9999588498417253, + 0.9998388659316617, + 0.9998496524281603, + 0.9999154673258592, + 0.9997827845182361, + 0.9998160234579786, + 0.9999163964511287, + 0.9999743435786168, + 0.9999894752861168, + ], + [ + 0.9999883847481621, + 0.9999427334014465, + 0.9997703972600652, + 0.9853967608835997, + 0.9852517829915376, + 0.9853308520519438, + 0.9854102394414211, + 0.9998728503298413, + 0.9999642585978225, + 0.999986204909933, + ], + [ + 0.999985544721449, + 0.9999296195017368, + 0.9997066149628903, + 0.9753803016111353, + 0.9750688049429371, + 0.9749211929217173, + 0.9750052047129354, + 0.9998284130289159, + 0.9999558481338295, + 0.9999837966320273, + ], + [ + 0.9999832723447848, + 0.9999192263814408, + 0.9996472692076177, + 0.90541293509353, + 0.9049945536526819, + 0.9051142437853055, + 0.9057005861296792, + 0.9997839348839027, + 0.9999490318922627, + 0.9999820419085812, + ], + [ + 0.9999815409510937, + 0.9999113168889934, + 0.9995930143319085, + 0.8370025145062345, + 0.8358345435164332, + 0.8358231468627223, + 0.8369430449157075, + 0.9997408260265034, + 0.9999437526409107, + 0.9999808010740554, + ], + [ + 0.9999803198262347, + 0.9999057164296593, + 0.9995461103528891, + 0.7047260555380003, + 0.7023346743490383, + 0.7022946969603594, + 0.7045662738042475, + 0.9997017258131392, + 0.9999399744001316, + 0.9999799785302944, + ], + [ + 0.9999795785255197, + 0.9999022923125928, + 0.999510772973329, + 0.46283993237260707, + 0.4577365087549323, + 0.4571888733219068, + 0.4614967878524538, + 0.9996710272733927, + 0.9999376682163403, + 0.9999795067125865, + ], + [ + 0.9999792877553907, + 0.9999009179811408, + 0.9994950057121632, + 0.05049460567213739, + 0.030946131978013824, + 0.0, + 0.019224121648385283, + 0.9996568912408903, + 0.9999367861122628, + 0.9999793358521326, + ], + ], + dtype=np.float32, ) SINK_MIN_OUTPUT = np.array( [ [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [ - 0.99997545, - 0.99996582, - 0.99995245, - 0.99856594, - 0.99898314, - 0.99777223, - 0.99394423, - 0.98588113, - 0.97283215, - 0.96096504, - ], - [ - 0.99993872, - 0.99993034, - 0.9998832, - 0.9986147, - 0.99848741, - 0.9972981, - 0.99723719, - 0.94157173, - 0.9369832, - 0.91964243, - ], - [ - 0.99990802, - 0.99989475, - 0.99986873, - 0.98610197, - 0.98610047, - 0.98609749, - 0.98609423, - 0.88741275, - 0.88112911, - 0.86349156, - ], - [ - 0.99988924, - 0.99988509, - 0.99988698, - 0.98234089, - 0.98233591, - 0.98233065, - 0.98232562, - 0.81475172, - 0.80865978, - 0.79033138, - ], - [ - 0.99988418, - 0.99988484, - 0.99988323, - 0.86796555, - 0.86795874, - 0.86795283, - 0.86794756, - 0.72418193, - 0.71847704, - 0.70022037, - ], - [ - 0.99988241, - 0.99988184, - 0.99988103, - 0.85528225, - 0.85527303, - 0.85526389, - 0.85525499, - 0.61716519, - 0.61026209, - 0.59503671, - ], - [ - 0.99988015, - 0.99987985, - 0.99987875, - 0.84258114, - 0.84257121, - 0.84256042, - 0.84254897, - 0.48997924, - 0.49083978, - 0.46891561, - ], - [ - 0.99987865, - 0.99987827, - 0.9998772, - 0.83279589, - 0.83278624, - 0.83277384, - 0.83275897, - 0.36345545, - 0.33690244, - 0.35696828, - ], - [ - 0.99987796, - 0.99987756, - 0.99987643, - 0.82873223, - 0.82872648, - 0.82871803, - 0.82870711, - 0.0, - 0.26106012, - 0.29978657, - ], - ] + 0.9999961997987318, + 0.9999801752476248, + 0.9999185667341594, + 0.9993115972922259, + 0.9999536433504382, + 0.9997590064584757, + 0.9963282396026231, + 0.9020645423682648, + 0.965641014946897, + 0.9847003633599846, + ], + [ + 0.9999926824858815, + 0.9999628275604145, + 0.9998472915971415, + 0.9992953054409239, + 0.9995550237000549, + 0.9972853256638443, + 0.9958871482234863, + 0.8006505271617617, + 0.9360757301263053, + 0.9734843475613124, + ], + [ + 0.9999896427490426, + 0.9999484707116104, + 0.9997841142091455, + 0.9321779021295554, + 0.9308591506422442, + 0.9299937642438358, + 0.9286536283468563, + 0.6964658886602826, + 0.9106656689679997, + 0.9652109119709528, + ], + [ + 0.9999871227708508, + 0.9999369646510842, + 0.9997276125796202, + 0.9006206490361908, + 0.8987968702587018, + 0.8965696900664386, + 0.8941507574801211, + 0.5892568658180841, + 0.8892240419729905, + 0.9590996257620853, + ], + [ + 0.9999851119906539, + 0.9999280075234918, + 0.9996788394671484, + 0.778755271203017, + 0.7763917808258874, + 0.7737517385551721, + 0.7707980517990098, + 0.4788014936236403, + 0.8715671104783401, + 0.954632732759503, + ], + [ + 0.9999835837292402, + 0.999921323618806, + 0.9996389455307461, + 0.7222961578407286, + 0.7186158832946955, + 0.7146983167265393, + 0.7105768254632475, + 0.3648911004360315, + 0.8575943501305144, + 0.9514642802768379, + ], + [ + 0.9999825081019064, + 0.999916683268467, + 0.9996093996776352, + 0.6713490686473397, + 0.6664914636518112, + 0.6613110504728309, + 0.6558325489984669, + 0.247299682539502, + 0.8473037957967624, + 0.9493580587294981, + ], + [ + 0.999981856118739, + 0.9999138938063622, + 0.9995907248497593, + 0.6331535096751639, + 0.6271637176135582, + 0.6206687804556549, + 0.6136262027168252, + 0.12576864809108962, + 0.8407892431959736, + 0.9481472656653798, + ], + [ + 0.9999816006081851, + 0.9999127861527936, + 0.9995832399159849, + 0.6133274396648696, + 0.6086364734302403, + 0.6034602717119345, + 0.5978473214165134, + 0.0, + 0.8382338778894218, + 0.9477082231321966, + ], + ], + dtype=np.float32, ) SINK_MASK_OUTPUT = np.array( [ + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.9047934405899283, 0.9936046284605553, 0.9448690902377527, 0.0, 0.0, 0.0, 0.8363773255131761], + [0.0, 0.0, 0.0, 0.90375200446097, 0.9434594475474036, 0.4716831449516178, 0.0, 0.0, 0.0, 0.7364197333910302], + [ + 0.0, + 0.0, + 0.0, + 0.09080438801405301, + 0.06774182873204163, + 0.038207095016625024, + 0.0, + 0.0, + 0.0, + 0.6745641479264269, + ], + [ + 0.0, + 0.0, + 0.0, + 0.01731082802870267, + 0.013540929458217351, + 0.007321202161532623, + 0.0, + 0.0, + 0.0, + 0.6341231654271253, + ], + [ + 0.0, + 0.0, + 0.0, + 0.0006444251665178544, + 0.0005397129128756325, + 0.0003048384803626333, + 0.0, + 0.0, + 0.0, + 0.6070178708536365, + ], + [ + 0.0, + 0.0, + 0.0, + 5.406078586212675e-05, + 4.416783924970537e-05, + 2.4597362039020103e-05, + 0.0, + 0.0, + 0.0, + 0.5889413683184284, + ], [ - 1.00000000e00, - 1.00000000e00, - 1.00000000e00, - 1.00000000e00, - 1.00000000e00, - 1.00000000e00, - 1.00000000e00, - 1.00000000e00, - 1.00000000e00, - 1.00000000e00, - ], - [ - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 2.86416400e-01, - 7.93271181e-01, - 5.81341234e-01, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 1.98395623e-01, - ], - [ - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 2.66733297e-01, - 2.80741490e-01, - 4.14078784e-02, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 7.91676486e-04, - ], - [ - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 1.86244537e-04, - 1.53413401e-04, - 7.85806495e-05, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 5.09797387e-06, - ], - [ - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 9.62904581e-07, - 7.23946225e-07, - 3.68824440e-07, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 4.79525316e-08, - ], - [ - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 1.50939343e-10, - 1.17724874e-10, - 6.21760843e-11, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 6.08922784e-10, - ], - [ - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 2.57593754e-13, - 1.94066716e-13, - 9.83784370e-14, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 9.80828665e-12, - ], - [ - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 4.22323494e-16, - 3.17556633e-16, - 1.60789400e-16, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 1.90789819e-13, - ], - [ - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 7.72677888e-19, - 5.83029424e-19, - 2.95946659e-19, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 4.97038275e-15, - ], - [ - 2.71345908e-24, - 5.92006757e-24, - 2.25580089e-23, - 3.82601970e-18, - 3.82835349e-18, - 3.83302158e-18, - 3.84002606e-18, - 8.40760586e-16, - 1.83433696e-15, - 1.11629633e-15, - ], - ] + 0.0, + 0.0, + 0.0, + 4.39259327223233e-06, + 3.6050656774754658e-06, + 2.0127120155893425e-06, + 0.0, + 0.0, + 0.0, + 0.5774279920364456, + ], + [ + 0.0, + 0.0, + 0.0, + 4.0740501726718113e-07, + 3.374875487404489e-07, + 1.9113630985667455e-07, + 0.0, + 0.0, + 0.0, + 0.5709897726747111, + ], + [ + 3.2266922388030425e-17, + 1.801110982679718e-14, + 9.325899448306927e-12, + 3.913608442133728e-07, + 3.9581822403393465e-07, + 4.02383505118481e-07, + 4.14820241328287e-07, + 4.281640797396309e-06, + 0.0023900192231620593, + 0.5686882523793125, + ], + ], + dtype=np.float32, ) @@ -527,6 +491,21 @@ def setUp(self): self.input_img_torch = torch.from_numpy(TEST_INPUT).unsqueeze(0) # mock image (torch tensor) self.input_mask_torch = torch.from_numpy(TEST_MASK).unsqueeze(0) # mock mask (torch tensor) + self.real_input_img_paths = [ + os.path.join(os.path.dirname(__file__), "testing_data", "ultrasound_confidence_map", "neck_input.png"), + os.path.join(os.path.dirname(__file__), "testing_data", "ultrasound_confidence_map", "femur_input.png"), + ] + + self.real_result_npy_paths = [ + os.path.join(os.path.dirname(__file__), "testing_data", "ultrasound_confidence_map", "neck_result.npy"), + os.path.join(os.path.dirname(__file__), "testing_data", "ultrasound_confidence_map", "femur_result.npy"), + ] + + self.real_input_paramaters = [ + {"alpha": 2.0, "beta": 90, "gamma": 0.03}, + {"alpha": 2.0, "beta": 90, "gamma": 0.06}, + ] + def test_parameters(self): # Unknown mode with self.assertRaises(ValueError): @@ -683,6 +662,44 @@ def test_func(self): output = transform(self.input_img_torch, self.input_mask_torch) assert_allclose(output, torch.tensor(SINK_MASK_OUTPUT), rtol=1e-4, atol=1e-4) + def test_against_official_code(self): + # This test is to compare the output of the transform with the official code + # The official code is available at: + # https://campar.in.tum.de/Main/AthanasiosKaramalisCode + + for input_img_path, result_npy_path, params in zip( + self.real_input_img_paths, self.real_result_npy_paths, self.real_input_paramaters + ): + input_img = np.array(Image.open(input_img_path)) + input_img = np.expand_dims(input_img, axis=0) + + result_img = np.load(result_npy_path) + + transform = UltrasoundConfidenceMapTransform(sink_mode="all", **params) + output = transform(input_img) + + assert_allclose(output, result_img, rtol=1e-4, atol=1e-4) + + def test_against_official_code_using_cg(self): + # This test is to compare the output of the transform with the official code + # The official code is available at: + # https://campar.in.tum.de/Main/AthanasiosKaramalisCode + + for input_img_path, result_npy_path, params in zip( + self.real_input_img_paths, self.real_result_npy_paths, self.real_input_paramaters + ): + input_img = np.array(Image.open(input_img_path)) + input_img = np.expand_dims(input_img, axis=0) + + result_img = np.load(result_npy_path) + + transform = UltrasoundConfidenceMapTransform( + sink_mode="all", use_cg=True, cg_tol=1.0e-6, cg_maxiter=300, **params + ) + output = transform(input_img) + + assert_allclose(output, result_img, rtol=1e-2, atol=1e-2) + if __name__ == "__main__": unittest.main() From ec060907f8985ec76fc8c74f39edde1f3a1416d4 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 2 Jul 2024 22:43:45 +0100 Subject: [PATCH 25/32] DCO Remediation Commit for Eric Kerfoot <17726042+ericspod@users.noreply.github.com> I, Eric Kerfoot <17726042+ericspod@users.noreply.github.com>, hereby add my Signed-off-by to this commit: 15ff66397d3eedf696855b92c5b66ba0ab624471 Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> From d5da737ecce26d291fbcdb07c94f4651436120f7 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 2 Jul 2024 23:08:47 +0100 Subject: [PATCH 26/32] Further conflict resolutions and replacing regressed changes Signed-off-by: Eric Kerfoot --- Dockerfile | 3 - docs/source/networks.rst | 10 +++ monai/data/video_dataset.py | 2 +- monai/utils/misc.py | 4 +- pyproject.toml | 1 - tests/test_hilbert_transform.py | 121 +++++++------------------------- 6 files changed, 37 insertions(+), 104 deletions(-) diff --git a/Dockerfile b/Dockerfile index 10931222dd..8e255597d1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -26,9 +26,6 @@ RUN if [[ $(uname -m) =~ "aarch64" ]]; then \ WORKDIR /opt/monai -# remove opencv-python before opencv-python-headless installation -RUN pip uninstall -y opencv && rm /usr/local/lib/python3.10/dist-packages/cv2 -r - # install full deps COPY requirements.txt requirements-min.txt requirements-dev.txt /tmp/ RUN cp /tmp/requirements.txt /tmp/req.bak \ diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 8eada7933f..249375dfc1 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -408,6 +408,11 @@ Layers .. autoclass:: LLTM :members: +`ConjugateGradient` +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ConjugateGradient + :members: + `Utilities` ~~~~~~~~~~~ .. automodule:: monai.networks.layers.convutils @@ -486,6 +491,11 @@ Nets .. autoclass:: ResNet :members: +`ResNetFeatures` +~~~~~~~~~~~~~~~~ +.. autoclass:: ResNetFeatures + :members: + `SENet` ~~~~~~~ .. autoclass:: SENet diff --git a/monai/data/video_dataset.py b/monai/data/video_dataset.py index 9ff23ebeff..031e85db26 100644 --- a/monai/data/video_dataset.py +++ b/monai/data/video_dataset.py @@ -177,7 +177,7 @@ def get_available_codecs() -> dict[str, str]: for codec, ext in all_codecs.items(): writer = cv2.VideoWriter() fname = os.path.join(tmp_dir, f"test{ext}") - fourcc = cv2.VideoWriter_fourcc(*codec) # type: ignore[attr-defined] + fourcc = cv2.VideoWriter_fourcc(*codec) noviderr = writer.open(fname, fourcc, 1, (10, 10)) if noviderr: codecs[codec] = ext diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 84dd3ad1f6..ab9fe259aa 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -889,11 +889,11 @@ def is_sqrt(num: Sequence[int] | int) -> bool: return ensure_tuple(ret) == num -def unsqueeze_right(arr: torch.Tensor, ndim: int) -> torch.Tensor: +def unsqueeze_right(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(...,) + (None,) * (ndim - arr.ndim)] -def unsqueeze_left(arr: torch.Tensor, ndim: int) -> torch.Tensor: +def unsqueeze_left(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(None,) * (ndim - arr.ndim)] diff --git a/pyproject.toml b/pyproject.toml index 50d0b09672..53ca608d20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,6 @@ exclude = "monai/bundle/__main__.py" [tool.ruff] line-length = 133 -lint.ignore-init-module-imports = true lint.ignore = ["F401", "E741"] [tool.pytype] diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py index 4c49aecd8b..a38b0a4956 100644 --- a/tests/test_hilbert_transform.py +++ b/tests/test_hilbert_transform.py @@ -19,11 +19,11 @@ from monai.networks.layers import HilbertTransform from monai.utils import OptionalImportError -from tests.utils import SkipIfModule, SkipIfNoModule, skip_if_no_cuda +from tests.utils import SkipIfModule, SkipIfNoModule def create_expected_numpy_output(input_datum, **kwargs): - x = np.fft.fft(input_datum.cpu().numpy() if input_datum.device.type == "cuda" else input_datum.numpy(), **kwargs) + x = np.fft.fft(input_datum.cpu().numpy(), **kwargs) f = np.fft.fftfreq(x.shape[kwargs["axis"]]) u = np.heaviside(f, 0.5) new_dims_before = kwargs["axis"] @@ -44,19 +44,15 @@ def create_expected_numpy_output(input_datum, **kwargs): # CPU TEST DATA cpu_input_data = {} -cpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=cpu).unsqueeze(0).unsqueeze(0) -cpu_input_data["2D"] = ( - torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu).unsqueeze(0).unsqueeze(0) -) -cpu_input_data["3D"] = ( - torch.as_tensor(np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu) - .unsqueeze(0) - .unsqueeze(0) -) -cpu_input_data["1D 2CH"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu).unsqueeze(0) +cpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=cpu)[None, None] +cpu_input_data["2D"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu)[None, None] +cpu_input_data["3D"] = torch.as_tensor( + np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu +)[None, None] +cpu_input_data["1D 2CH"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu)[None] cpu_input_data["2D 2CH"] = torch.as_tensor( np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu -).unsqueeze(0) +)[None] # SINGLE-CHANNEL CPU VALUE TESTS @@ -97,111 +93,42 @@ def create_expected_numpy_output(input_datum, **kwargs): 1e-5, # absolute tolerance ] +TEST_CASES_CPU = [ + TEST_CASE_1D_SINE_CPU, + TEST_CASE_2D_SINE_CPU, + TEST_CASE_3D_SINE_CPU, + TEST_CASE_1D_2CH_SINE_CPU, + TEST_CASE_2D_2CH_SINE_CPU, +] + # GPU TEST DATA if torch.cuda.is_available(): gpu = torch.device("cuda") - - gpu_input_data = {} - gpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=gpu).unsqueeze(0).unsqueeze(0) - gpu_input_data["2D"] = ( - torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=gpu).unsqueeze(0).unsqueeze(0) - ) - gpu_input_data["3D"] = ( - torch.as_tensor(np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=gpu) - .unsqueeze(0) - .unsqueeze(0) - ) - gpu_input_data["1D 2CH"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=gpu).unsqueeze(0) - gpu_input_data["2D 2CH"] = torch.as_tensor( - np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=gpu - ).unsqueeze(0) - - # SINGLE CHANNEL GPU VALUE TESTS - - TEST_CASE_1D_SINE_GPU = [ - {}, # args (empty, so use default) - gpu_input_data["1D"], # Input data: Random 1D signal - create_expected_numpy_output(gpu_input_data["1D"], axis=2), # Expected output: FFT of signal - 1e-5, # absolute tolerance - ] - - TEST_CASE_2D_SINE_GPU = [ - {}, # args (empty, so use default) - gpu_input_data["2D"], # Input data: Random 1D signal - create_expected_numpy_output(gpu_input_data["2D"], axis=2), # Expected output: FFT of signal - 1e-5, # absolute tolerance - ] - - TEST_CASE_3D_SINE_GPU = [ - {}, # args (empty, so use default) - gpu_input_data["3D"], # Input data: Random 1D signal - create_expected_numpy_output(gpu_input_data["3D"], axis=2), # Expected output: FFT of signal - 1e-5, # absolute tolerance - ] - - # MULTICHANNEL GPU VALUE TESTS, PROCESS ALONG FIRST SPATIAL AXIS - - TEST_CASE_1D_2CH_SINE_GPU = [ - {}, # args (empty, so use default) - gpu_input_data["1D 2CH"], # Input data: Random 1D signal - create_expected_numpy_output(gpu_input_data["1D 2CH"], axis=2), - 1e-5, # absolute tolerance - ] - - TEST_CASE_2D_2CH_SINE_GPU = [ - {}, # args (empty, so use default) - gpu_input_data["2D 2CH"], # Input data: Random 1D signal - create_expected_numpy_output(gpu_input_data["2D 2CH"], axis=2), - 1e-5, # absolute tolerance - ] + TEST_CASES_GPU = [[args, image.to(gpu), exp_data, atol] for args, image, exp_data, atol in TEST_CASES_CPU] +else: + TEST_CASES_GPU = [] # TESTS CHECKING PADDING, AXIS SELECTION ETC ARE COVERED BY test_detect_envelope.py @SkipIfNoModule("torch.fft") class TestHilbertTransformCPU(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_1D_SINE_CPU, - TEST_CASE_2D_SINE_CPU, - TEST_CASE_3D_SINE_CPU, - TEST_CASE_1D_2CH_SINE_CPU, - TEST_CASE_2D_2CH_SINE_CPU, - ] - ) - def test_value(self, arguments, image, expected_data, atol): - result = HilbertTransform(**arguments)(image) - result = result.squeeze(0).squeeze(0).numpy() - np.testing.assert_allclose(result, expected_data.squeeze(), atol=atol) - -@skip_if_no_cuda -@SkipIfNoModule("torch.fft") -class TestHilbertTransformGPU(unittest.TestCase): - @parameterized.expand( - [] - if not torch.cuda.is_available() - else [ - TEST_CASE_1D_SINE_GPU, - TEST_CASE_2D_SINE_GPU, - TEST_CASE_3D_SINE_GPU, - TEST_CASE_1D_2CH_SINE_GPU, - TEST_CASE_2D_2CH_SINE_GPU, - ], - skip_on_empty=True, - ) + @parameterized.expand(TEST_CASES_CPU + TEST_CASES_GPU) def test_value(self, arguments, image, expected_data, atol): result = HilbertTransform(**arguments)(image) - result = result.squeeze(0).squeeze(0).cpu().numpy() + result = np.squeeze(result.cpu().numpy()) np.testing.assert_allclose(result, expected_data.squeeze(), atol=atol) @SkipIfModule("torch.fft") class TestHilbertTransformNoFFTMod(unittest.TestCase): + def test_no_fft_module_error(self): self.assertRaises(OptionalImportError, HilbertTransform(), torch.randn(1, 1, 10)) if __name__ == "__main__": unittest.main() + \ No newline at end of file From ff9073638fbf4e3a089d0bbb6e9e328afd1e3bd2 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 2 Jul 2024 23:16:44 +0100 Subject: [PATCH 27/32] Formatting Signed-off-by: Eric Kerfoot --- tests/test_hilbert_transform.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py index a38b0a4956..b91ba3f6b7 100644 --- a/tests/test_hilbert_transform.py +++ b/tests/test_hilbert_transform.py @@ -131,4 +131,3 @@ def test_no_fft_module_error(self): if __name__ == "__main__": unittest.main() - \ No newline at end of file From 78287810332e0a510375c48d6d55aac16e0f77b0 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 10 Jul 2024 02:20:33 +0100 Subject: [PATCH 28/32] Update to merge changes to SABlock, possibly resolving conflicts between versions Signed-off-by: Eric Kerfoot --- monai/networks/blocks/selfattention.py | 82 ++++++++++++++++++++++++-- 1 file changed, 77 insertions(+), 5 deletions(-) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 7c81c1704f..9d6526d70a 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -11,9 +11,12 @@ from __future__ import annotations +from typing import Optional, Tuple + import torch import torch.nn as nn +from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -32,6 +35,13 @@ def __init__( dropout_rate: float = 0.0, qkv_bias: bool = False, save_attn: bool = False, + dim_head: int | None = None, + hidden_input_size: int | None = None, + causal: bool = False, + sequence_length: int | None = None, + rel_pos_embedding: Optional[str] = None, + input_size: Optional[Tuple] = None, + attention_dtype: Optional[torch.dtype] = None ) -> None: """ Args: @@ -40,6 +50,15 @@ def __init__( dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. + dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads. + hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size. + causal: whether to use causal attention (see https://arxiv.org/abs/1706.03762). + sequence_length: if causal is True, it is necessary to specify the sequence length. + rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. + For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. + input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative + positional parameter size. + attention_dtype: cast attention operations to this dtype. """ @@ -51,22 +70,74 @@ def __init__( if hidden_size % num_heads != 0: raise ValueError("hidden size should be divisible by num_heads.") + if dim_head: + self.inner_dim = num_heads * dim_head + self.dim_head = dim_head + else: + if hidden_size % num_heads != 0: + raise ValueError("hidden size should be divisible by num_heads.") + self.inner_dim = hidden_size + self.dim_head = hidden_size // num_heads + + if causal and sequence_length is None: + raise ValueError("sequence_length is necessary for causal attention.") + self.num_heads = num_heads - self.out_proj = nn.Linear(hidden_size, hidden_size) - self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) + self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size + self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size) + + self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias) self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) self.out_rearrange = Rearrange("b h l d -> b l (h d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) - self.head_dim = hidden_size // num_heads - self.scale = self.head_dim**-0.5 + self.scale = self.dim_head**-0.5 self.save_attn = save_attn self.att_mat = torch.Tensor() + self.attention_dtype = attention_dtype + self.causal = causal + self.sequence_length = sequence_length + + if causal and sequence_length is not None: + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "causal_mask", + torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), + ) + self.causal_mask: torch.Tensor + + self.rel_positional_embedding = ( + get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.dim_head, self.num_heads) + if rel_pos_embedding is not None + else None + ) + self.input_size = input_size def forward(self, x): + """ + Args: + x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C + + Return: + torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C + """ output = self.input_rearrange(self.qkv(x)) q, k, v = output[0], output[1], output[2] - att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) + + if self.attention_dtype is not None: + q = q.to(self.attention_dtype) + k = k.to(self.attention_dtype) + + att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale + + # apply relative positional embedding if defined + att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + + if self.causal: + att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf")) + + att_mat = att_mat.softmax(dim=-1) + if self.save_attn: # no gradients and new tensor; # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html @@ -78,3 +149,4 @@ def forward(self, x): x = self.out_proj(x) x = self.drop_output(x) return x + \ No newline at end of file From 54e180dce7da98099bd5ce3519795442dfaf6e08 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 10 Jul 2024 03:10:17 +0100 Subject: [PATCH 29/32] Formatting Signed-off-by: Eric Kerfoot --- monai/networks/blocks/selfattention.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 9d6526d70a..9905e7d036 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -41,7 +41,7 @@ def __init__( sequence_length: int | None = None, rel_pos_embedding: Optional[str] = None, input_size: Optional[Tuple] = None, - attention_dtype: Optional[torch.dtype] = None + attention_dtype: Optional[torch.dtype] = None, ) -> None: """ Args: @@ -149,4 +149,3 @@ def forward(self, x): x = self.out_proj(x) x = self.drop_output(x) return x - \ No newline at end of file From b691308ee42baefb83de960e5d64f9506b7b2fa5 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 17 Jul 2024 13:41:19 +0100 Subject: [PATCH 30/32] Typing fix Signed-off-by: Eric Kerfoot --- monai/networks/schedulers/scheduler.py | 6 ++++-- monai/utils/misc.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/monai/networks/schedulers/scheduler.py b/monai/networks/schedulers/scheduler.py index 17bb526abc..acdccc60de 100644 --- a/monai/networks/schedulers/scheduler.py +++ b/monai/networks/schedulers/scheduler.py @@ -196,8 +196,10 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) timesteps = timesteps.to(sample.device) - sqrt_alpha_prod = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim) - sqrt_one_minus_alpha_prod = unsqueeze_right((1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim) + sqrt_alpha_prod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim) + sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right( + (1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim + ) velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 96a59e1b35..d71f5d4f33 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -118,6 +118,7 @@ def star_zip_with(op, *vals): T = TypeVar("T") +NT = TypeVar("NT", np.ndarray, torch.Tensor) @overload @@ -907,11 +908,11 @@ def is_sqrt(num: Sequence[int] | int) -> bool: return ensure_tuple(ret) == num -def unsqueeze_right(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: +def unsqueeze_right(arr: NT, ndim: int) -> NT: """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(...,) + (None,) * (ndim - arr.ndim)] -def unsqueeze_left(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: +def unsqueeze_left(arr: NT, ndim: int) -> NT: """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(None,) * (ndim - arr.ndim)] From 95c73ec9f0ea867a7fa453fb8de50484f7d0ad03 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Fri, 19 Jul 2024 14:00:22 +0100 Subject: [PATCH 31/32] Minor fix Signed-off-by: Eric Kerfoot --- monai/networks/nets/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py index 215e8d11a9..1af725abda 100644 --- a/monai/networks/nets/transformer.py +++ b/monai/networks/nets/transformer.py @@ -138,7 +138,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # fix the attention blocks attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn.qkv.weight"] = torch.concat( + new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat( [ old_state_dict[f"{block}.attn.to_q.weight"], old_state_dict[f"{block}.attn.to_k.weight"], From e1d179035608bc07c0876334e18e457a97784d43 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Fri, 19 Jul 2024 15:07:44 +0100 Subject: [PATCH 32/32] Minor fix Signed-off-by: Eric Kerfoot --- monai/networks/nets/autoencoderkl.py | 4 ++-- monai/networks/nets/controlnet.py | 2 +- monai/networks/nets/diffusion_model_unet.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 17bb90d6f6..35d80e0565 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -670,7 +670,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # fix the attention blocks attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn.qkv.weight"] = torch.concat( + new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat( [ old_state_dict[f"{block}.to_q.weight"], old_state_dict[f"{block}.to_k.weight"], @@ -678,7 +678,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: ], dim=0, ) - new_state_dict[f"{block}.attn.qkv.bias"] = torch.concat( + new_state_dict[f"{block}.attn.qkv.bias"] = torch.cat( [ old_state_dict[f"{block}.to_q.bias"], old_state_dict[f"{block}.to_k.bias"], diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py index fe6746e017..ed3654733d 100644 --- a/monai/networks/nets/controlnet.py +++ b/monai/networks/nets/controlnet.py @@ -446,7 +446,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # fix the attention blocks attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn1.qkv.weight"] = torch.concat( + new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat( [ old_state_dict[f"{block}.attn1.to_q.weight"], old_state_dict[f"{block}.attn1.to_k.weight"], diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index f995d20e54..8a9ac859a3 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -1714,7 +1714,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # fix the attention blocks attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn1.qkv.weight"] = torch.concat( + new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat( [ old_state_dict[f"{block}.attn1.to_q.weight"], old_state_dict[f"{block}.attn1.to_k.weight"],