From b3bcf1201d2b0bd80a1a6ac68c51973c80340608 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Mon, 1 Jul 2024 15:50:18 +0000 Subject: [PATCH 01/19] add mor funcs Signed-off-by: Can-Zhao --- monai/apps/generation/__init__.py | 10 + monai/apps/generation/maisi/__init__.py | 10 + .../generation/maisi/networks/__init__.py | 10 + .../maisi/networks/autoencoderkl_maisi.py | 975 ++++++++++++++++++ .../maisi/networks/controlnet_maisi.py | 178 ++++ .../maisi/utils/morphological_ops.py | 170 +++ tests/test_morphological_ops.py | 130 +++ 7 files changed, 1483 insertions(+) create mode 100644 monai/apps/generation/__init__.py create mode 100644 monai/apps/generation/maisi/__init__.py create mode 100644 monai/apps/generation/maisi/networks/__init__.py create mode 100644 monai/apps/generation/maisi/networks/autoencoderkl_maisi.py create mode 100644 monai/apps/generation/maisi/networks/controlnet_maisi.py create mode 100644 monai/apps/generation/maisi/utils/morphological_ops.py create mode 100644 tests/test_morphological_ops.py diff --git a/monai/apps/generation/__init__.py b/monai/apps/generation/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/apps/generation/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/apps/generation/maisi/__init__.py b/monai/apps/generation/maisi/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/apps/generation/maisi/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/apps/generation/maisi/networks/__init__.py b/monai/apps/generation/maisi/networks/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/apps/generation/maisi/networks/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py new file mode 100644 index 0000000000..533da32fa0 --- /dev/null +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -0,0 +1,975 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import gc +import logging +from typing import TYPE_CHECKING, Sequence, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution +from monai.utils import optional_import +from monai.utils.type_conversion import convert_to_tensor + +AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock") +AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL") +ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock") + + +if TYPE_CHECKING: + from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType +else: + AutoencoderKLType = cast(type, AutoencoderKL) + + +# Set up logging configuration +logger = logging.getLogger(__name__) + + +def _empty_cuda_cache(save_mem: bool) -> None: + if torch.cuda.is_available() and save_mem: + torch.cuda.empty_cache() + return + + +class MaisiGroupNorm3D(nn.GroupNorm): + """ + Custom 3D Group Normalization with optional print_info output. + + Args: + num_groups: Number of groups for the group norm. + num_channels: Number of channels for the group norm. + eps: Epsilon value for numerical stability. + affine: Whether to use learnable affine parameters, default to `True`. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. + print_info: Whether to print information, default to `False`. + save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. + """ + + def __init__( + self, + num_groups: int, + num_channels: int, + eps: float = 1e-5, + affine: bool = True, + norm_float16: bool = False, + print_info: bool = False, + save_mem: bool = True, + ): + super().__init__(num_groups, num_channels, eps, affine) + self.norm_float16 = norm_float16 + self.print_info = print_info + self.save_mem = save_mem + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.print_info: + logger.info(f"MaisiGroupNorm3D with input size: {input.size()}") + + if len(input.shape) != 5: + raise ValueError("Expected a 5D tensor") + + param_n, param_c, param_d, param_h, param_w = input.shape + input = input.view(param_n, self.num_groups, param_c // self.num_groups, param_d, param_h, param_w) + + inputs = [] + for i in range(input.size(1)): + array = input[:, i : i + 1, ...].to(dtype=torch.float32) + mean = array.mean([2, 3, 4, 5], keepdim=True) + std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_() + if self.norm_float16: + inputs.append(((array - mean) / std).to(dtype=torch.float16)) + else: + inputs.append((array - mean) / std) + + del input + _empty_cuda_cache(self.save_mem) + + input = torch.cat(inputs, dim=1) if max(inputs[0].size()) < 500 else self._cat_inputs(inputs) + + input = input.view(param_n, param_c, param_d, param_h, param_w) + if self.affine: + input.mul_(self.weight.view(1, param_c, 1, 1, 1)).add_(self.bias.view(1, param_c, 1, 1, 1)) + + if self.print_info: + logger.info(f"MaisiGroupNorm3D with output size: {input.size()}") + + return input + + def _cat_inputs(self, inputs): + input_type = inputs[0].device.type + input = inputs[0].clone().to("cpu", non_blocking=True) if input_type == "cuda" else inputs[0].clone() + inputs[0] = 0 + _empty_cuda_cache(self.save_mem) + + for k in range(len(inputs) - 1): + input = torch.cat((input, inputs[k + 1].cpu()), dim=1) + inputs[k + 1] = 0 + _empty_cuda_cache(self.save_mem) + gc.collect() + + if self.print_info: + logger.info(f"MaisiGroupNorm3D concat progress: {k + 1}/{len(inputs) - 1}.") + + return input.to("cuda", non_blocking=True) if input_type == "cuda" else input + + +class MaisiConvolution(nn.Module): + """ + Convolutional layer with optional print_info output and custom splitting mechanism. + + Args: + spatial_dims: Number of spatial dimensions (1D, 2D, 3D). + in_channels: Number of input channels. + out_channels: Number of output channels. + num_splits: Number of splits for the input tensor. + dim_split: Dimension of splitting for the input tensor. + print_info: Whether to print information. + save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. + Additional arguments for the convolution operation. + https://docs.monai.io/en/stable/networks.html#convolution + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_splits: int, + dim_split: int, + print_info: bool, + save_mem: bool = True, + strides: Sequence[int] | int = 1, + kernel_size: Sequence[int] | int = 3, + adn_ordering: str = "NDA", + act: tuple | str | None = "PRELU", + norm: tuple | str | None = "INSTANCE", + dropout: tuple | str | float | None = None, + dropout_dim: int = 1, + dilation: Sequence[int] | int = 1, + groups: int = 1, + bias: bool = True, + conv_only: bool = False, + is_transposed: bool = False, + padding: Sequence[int] | int | None = None, + output_padding: Sequence[int] | int | None = None, + ) -> None: + super().__init__() + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + strides=strides, + kernel_size=kernel_size, + adn_ordering=adn_ordering, + act=act, + norm=norm, + dropout=dropout, + dropout_dim=dropout_dim, + dilation=dilation, + groups=groups, + bias=bias, + conv_only=conv_only, + is_transposed=is_transposed, + padding=padding, + output_padding=output_padding, + ) + + self.dim_split = dim_split + self.stride = strides[self.dim_split] if isinstance(strides, list) else strides + self.num_splits = num_splits + self.print_info = print_info + self.save_mem = save_mem + + def _split_tensor(self, x: torch.Tensor, split_size: int, padding: int) -> list[torch.Tensor]: + overlaps = [0] + [padding] * (self.num_splits - 1) + last_padding = x.size(self.dim_split + 2) % split_size + + slices = [slice(None)] * 5 + splits: list[torch.Tensor] = [] + for i in range(self.num_splits): + slices[self.dim_split + 2] = slice( + i * split_size - overlaps[i], + (i + 1) * split_size + (padding if i != self.num_splits - 1 else last_padding), + ) + splits.append(x[tuple(slices)]) + + if self.print_info: + for j in range(len(splits)): + logger.info(f"Split {j + 1}/{len(splits)} size: {splits[j].size()}") + + return splits + + def _concatenate_tensors(self, outputs: list[torch.Tensor], split_size: int, padding: int) -> torch.Tensor: + slices = [slice(None)] * 5 + for i in range(self.num_splits): + slices[self.dim_split + 2] = slice(None, split_size) if i == 0 else slice(padding, padding + split_size) + outputs[i] = outputs[i][tuple(slices)] + + if self.print_info: + for i in range(self.num_splits): + logger.info(f"Output {i + 1}/{len(outputs)} size after: {outputs[i].size()}") + + if max(outputs[0].size()) < 500: + x = torch.cat(outputs, dim=self.dim_split + 2) + else: + x = outputs[0].clone().to("cpu", non_blocking=True) + outputs[0] = torch.Tensor(0) + _empty_cuda_cache(self.save_mem) + for k in range(len(outputs) - 1): + x = torch.cat((x, outputs[k + 1].cpu()), dim=self.dim_split + 2) + outputs[k + 1] = torch.Tensor(0) + _empty_cuda_cache(self.save_mem) + gc.collect() + if self.print_info: + logger.info(f"MaisiConvolution concat progress: {k + 1}/{len(outputs) - 1}.") + + x = x.to("cuda", non_blocking=True) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.print_info: + logger.info(f"Number of splits: {self.num_splits}") + + # compute size of splits + l = x.size(self.dim_split + 2) + split_size = l // self.num_splits + + # update padding length if necessary + padding = 3 + if padding % self.stride > 0: + padding = (padding // self.stride + 1) * self.stride + if self.print_info: + logger.info(f"Padding size: {padding}") + + # split tensor into a list of tensors + splits = self._split_tensor(x, split_size, padding) + + del x + _empty_cuda_cache(self.save_mem) + + # convolution + outputs = [self.conv(split) for split in splits] + if self.print_info: + for j in range(len(outputs)): + logger.info(f"Output {j + 1}/{len(outputs)} size before: {outputs[j].size()}") + + # update size of splits and padding length for output + split_size_out = split_size + padding_s = padding + non_dim_split = self.dim_split + 1 if self.dim_split < 2 else 0 + if outputs[0].size(non_dim_split + 2) // splits[0].size(non_dim_split + 2) == 2: + split_size_out *= 2 + padding_s *= 2 + elif splits[0].size(non_dim_split + 2) // outputs[0].size(non_dim_split + 2) == 2: + split_size_out //= 2 + padding_s //= 2 + + # concatenate list of tensors + x = self._concatenate_tensors(outputs, split_size_out, padding_s) + + del outputs + _empty_cuda_cache(self.save_mem) + + return x + + +class MaisiUpsample(nn.Module): + """ + Convolution-based upsampling layer. + + Args: + spatial_dims: Number of spatial dimensions (1D, 2D, 3D). + in_channels: Number of input channels to the layer. + use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. + num_splits: Number of splits for the input tensor. + dim_split: Dimension of splitting for the input tensor. + print_info: Whether to print information. + save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + use_convtranspose: bool, + num_splits: int, + dim_split: int, + print_info: bool, + save_mem: bool = True, + ) -> None: + super().__init__() + self.conv = MaisiConvolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2 if use_convtranspose else 1, + kernel_size=3, + padding=1, + conv_only=True, + is_transposed=use_convtranspose, + num_splits=num_splits, + dim_split=dim_split, + print_info=print_info, + save_mem=save_mem, + ) + self.use_convtranspose = use_convtranspose + self.save_mem = save_mem + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_convtranspose: + x = self.conv(x) + x_tensor: torch.Tensor = convert_to_tensor(x) + return x_tensor + + x = F.interpolate(x, scale_factor=2.0, mode="trilinear") + _empty_cuda_cache(self.save_mem) + x = self.conv(x) + _empty_cuda_cache(self.save_mem) + + out_tensor: torch.Tensor = convert_to_tensor(x) + return out_tensor + + +class MaisiDownsample(nn.Module): + """ + Convolution-based downsampling layer. + + Args: + spatial_dims: Number of spatial dimensions (1D, 2D, 3D). + in_channels: Number of input channels. + num_splits: Number of splits for the input tensor. + dim_split: Dimension of splitting for the input tensor. + print_info: Whether to print information. + save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_splits: int, + dim_split: int, + print_info: bool, + save_mem: bool = True, + ) -> None: + super().__init__() + self.pad = (0, 1) * spatial_dims + self.conv = MaisiConvolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=0, + conv_only=True, + num_splits=num_splits, + dim_split=dim_split, + print_info=print_info, + save_mem=save_mem, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, self.pad, mode="constant", value=0.0) + x = self.conv(x) + return x + + +class MaisiResBlock(nn.Module): + """ + Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a + residual connection between input and output. + + Args: + spatial_dims: Number of spatial dimensions (1D, 2D, 3D). + in_channels: Input channels to the layer. + norm_num_groups: Number of groups for the group norm layer. + norm_eps: Epsilon for the normalization. + out_channels: Number of output channels. + num_splits: Number of splits for the input tensor. + dim_split: Dimension of splitting for the input tensor. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. + print_info: Whether to print information, default to `False`. + save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + norm_num_groups: int, + norm_eps: float, + out_channels: int, + num_splits: int, + dim_split: int, + norm_float16: bool = False, + print_info: bool = False, + save_mem: bool = True, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.save_mem = save_mem + + self.norm1 = MaisiGroupNorm3D( + num_groups=norm_num_groups, + num_channels=in_channels, + eps=norm_eps, + affine=True, + norm_float16=norm_float16, + print_info=print_info, + save_mem=save_mem, + ) + self.conv1 = MaisiConvolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + num_splits=num_splits, + dim_split=dim_split, + print_info=print_info, + save_mem=save_mem, + ) + self.norm2 = MaisiGroupNorm3D( + num_groups=norm_num_groups, + num_channels=out_channels, + eps=norm_eps, + affine=True, + norm_float16=norm_float16, + print_info=print_info, + save_mem=save_mem, + ) + self.conv2 = MaisiConvolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + num_splits=num_splits, + dim_split=dim_split, + print_info=print_info, + save_mem=save_mem, + ) + + self.nin_shortcut = ( + MaisiConvolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + num_splits=num_splits, + dim_split=dim_split, + print_info=print_info, + save_mem=save_mem, + ) + if self.in_channels != self.out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + _empty_cuda_cache(self.save_mem) + + h = F.silu(h) + _empty_cuda_cache(self.save_mem) + h = self.conv1(h) + _empty_cuda_cache(self.save_mem) + + h = self.norm2(h) + _empty_cuda_cache(self.save_mem) + + h = F.silu(h) + _empty_cuda_cache(self.save_mem) + h = self.conv2(h) + _empty_cuda_cache(self.save_mem) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + _empty_cuda_cache(self.save_mem) + + out = x + h + out_tensor: torch.Tensor = convert_to_tensor(out) + return out_tensor + + +class MaisiEncoder(nn.Module): + """ + Convolutional cascade that downsamples the image into a spatial latent space. + + Args: + spatial_dims: Number of spatial dimensions (1D, 2D, 3D). + in_channels: Number of input channels. + num_channels: Sequence of block output channels. + out_channels: Number of channels in the bottom layer (latent space) of the autoencoder. + num_res_blocks: Number of residual blocks (see ResBlock) per level. + norm_num_groups: Number of groups for the group norm layers. + norm_eps: Epsilon for the normalization. + attention_levels: Indicate which level from num_channels contain an attention block. + with_nonlocal_attn: If True, use non-local attention block. + use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. + num_splits: Number of splits for the input tensor. + dim_split: Dimension of splitting for the input tensor. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. + print_info: Whether to print information, default to `False`. + save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_channels: Sequence[int], + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + num_splits: int, + dim_split: int, + norm_float16: bool = False, + print_info: bool = False, + save_mem: bool = True, + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + + # Check if attention_levels and num_channels have the same size + if len(attention_levels) != len(num_channels): + raise ValueError("attention_levels and num_channels must have the same size") + + # Check if num_res_blocks and num_channels have the same size + if len(num_res_blocks) != len(num_channels): + raise ValueError("num_res_blocks and num_channels must have the same size") + + self.save_mem = save_mem + + blocks: list[nn.Module] = [] + + blocks.append( + MaisiConvolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + num_splits=num_splits, + dim_split=dim_split, + print_info=print_info, + save_mem=save_mem, + ) + ) + + output_channel = num_channels[0] + for i in range(len(num_channels)): + input_channel = output_channel + output_channel = num_channels[i] + is_final_block = i == len(num_channels) - 1 + + for _ in range(num_res_blocks[i]): + blocks.append( + MaisiResBlock( + spatial_dims=spatial_dims, + in_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=output_channel, + num_splits=num_splits, + dim_split=dim_split, + norm_float16=norm_float16, + print_info=print_info, + save_mem=save_mem, + ) + ) + input_channel = output_channel + if attention_levels[i]: + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append( + MaisiDownsample( + spatial_dims=spatial_dims, + in_channels=input_channel, + num_splits=num_splits, + dim_split=dim_split, + print_info=print_info, + save_mem=save_mem, + ) + ) + + if with_nonlocal_attn: + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=num_channels[-1], + ) + ) + + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=num_channels[-1], + ) + ) + + blocks.append( + MaisiGroupNorm3D( + num_groups=norm_num_groups, + num_channels=num_channels[-1], + eps=norm_eps, + affine=True, + norm_float16=norm_float16, + print_info=print_info, + save_mem=save_mem, + ) + ) + blocks.append( + MaisiConvolution( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + num_splits=num_splits, + dim_split=dim_split, + print_info=print_info, + save_mem=save_mem, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + _empty_cuda_cache(self.save_mem) + return x + + +class MaisiDecoder(nn.Module): + """ + Convolutional cascade upsampling from a spatial latent space into an image space. + + Args: + spatial_dims: Number of spatial dimensions (1D, 2D, 3D). + num_channels: Sequence of block output channels. + in_channels: Number of channels in the bottom layer (latent space) of the autoencoder. + out_channels: Number of output channels. + num_res_blocks: Number of residual blocks (see ResBlock) per level. + norm_num_groups: Number of groups for the group norm layers. + norm_eps: Epsilon for the normalization. + attention_levels: Indicate which level from num_channels contain an attention block. + with_nonlocal_attn: If True, use non-local attention block. + use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. + use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. + num_splits: Number of splits for the input tensor. + dim_split: Dimension of splitting for the input tensor. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. + print_info: Whether to print information, default to `False`. + save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. + """ + + def __init__( + self, + spatial_dims: int, + num_channels: Sequence[int], + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + num_splits: int, + dim_split: int, + norm_float16: bool = False, + print_info: bool = False, + save_mem: bool = True, + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_convtranspose: bool = False, + ) -> None: + super().__init__() + self.print_info = print_info + self.save_mem = save_mem + + reversed_block_out_channels = list(reversed(num_channels)) + + blocks: list[nn.Module] = [] + + blocks.append( + MaisiConvolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=reversed_block_out_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + num_splits=num_splits, + dim_split=dim_split, + print_info=print_info, + save_mem=save_mem, + ) + ) + + if with_nonlocal_attn: + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + block_out_ch = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + block_in_ch = block_out_ch + block_out_ch = reversed_block_out_channels[i] + is_final_block = i == len(num_channels) - 1 + + for _ in range(reversed_num_res_blocks[i]): + blocks.append( + MaisiResBlock( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=block_out_ch, + num_splits=num_splits, + dim_split=dim_split, + norm_float16=norm_float16, + print_info=print_info, + save_mem=save_mem, + ) + ) + block_in_ch = block_out_ch + + if reversed_attention_levels[i]: + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append( + MaisiUpsample( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + use_convtranspose=use_convtranspose, + num_splits=num_splits, + dim_split=dim_split, + print_info=print_info, + save_mem=save_mem, + ) + ) + + blocks.append( + MaisiGroupNorm3D( + num_groups=norm_num_groups, + num_channels=block_in_ch, + eps=norm_eps, + affine=True, + norm_float16=norm_float16, + print_info=print_info, + save_mem=save_mem, + ) + ) + blocks.append( + MaisiConvolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + num_splits=num_splits, + dim_split=dim_split, + print_info=print_info, + save_mem=save_mem, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + _empty_cuda_cache(self.save_mem) + return x + + +class AutoencoderKlMaisi(AutoencoderKLType): + """ + AutoencoderKL with custom MaisiEncoder and MaisiDecoder. + + Args: + spatial_dims: Number of spatial dimensions (1D, 2D, 3D). + in_channels: Number of input channels. + out_channels: Number of output channels. + num_res_blocks: Number of residual blocks per level. + num_channels: Sequence of block output channels. + attention_levels: Indicate which level from num_channels contain an attention block. + latent_channels: Number of channels in the latent space. + norm_num_groups: Number of groups for the group norm layers. + norm_eps: Epsilon for the normalization. + with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder. + with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder. + use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. + use_checkpointing: If True, use activation checkpointing. + use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. + num_splits: Number of splits for the input tensor. + dim_split: Dimension of splitting for the input tensor. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. + print_info: Whether to print information, default to `False`. + save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + num_channels: Sequence[int], + attention_levels: Sequence[bool], + latent_channels: int = 3, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + with_encoder_nonlocal_attn: bool = False, + with_decoder_nonlocal_attn: bool = False, + use_flash_attention: bool = False, + use_checkpointing: bool = False, + use_convtranspose: bool = False, + num_splits: int = 16, + dim_split: int = 0, + norm_float16: bool = False, + print_info: bool = False, + save_mem: bool = True, + ) -> None: + super().__init__( + spatial_dims, + in_channels, + out_channels, + num_res_blocks, + num_channels, + attention_levels, + latent_channels, + norm_num_groups, + norm_eps, + with_encoder_nonlocal_attn, + with_decoder_nonlocal_attn, + use_flash_attention, + use_checkpointing, + use_convtranspose, + ) + + self.encoder = MaisiEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_channels=num_channels, + out_channels=latent_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_encoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + num_splits=num_splits, + dim_split=dim_split, + norm_float16=norm_float16, + print_info=print_info, + save_mem=save_mem, + ) + + self.decoder = MaisiDecoder( + spatial_dims=spatial_dims, + num_channels=num_channels, + in_channels=latent_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_decoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + use_convtranspose=use_convtranspose, + num_splits=num_splits, + dim_split=dim_split, + norm_float16=norm_float16, + print_info=print_info, + save_mem=save_mem, + ) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py new file mode 100644 index 0000000000..3641124b7d --- /dev/null +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -0,0 +1,178 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence, cast + +import torch + +from monai.utils import optional_import + +ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet") +get_timestep_embedding, has_get_timestep_embedding = optional_import( + "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding" +) + +if TYPE_CHECKING: + from generative.networks.nets.controlnet import ControlNet as ControlNetType +else: + ControlNetType = cast(type, ControlNet) + + +class ControlNetMaisi(ControlNetType): + """ + Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image + Diffusion Models" (https://arxiv.org/abs/2302.05543) + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + num_channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + conditioning_embedding_in_channels: number of input channels for the conditioning embedding. + conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. + use_checkpointing: if True, use activation checkpointing to save memory. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + conditioning_embedding_in_channels: int = 1, + conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256), + use_checkpointing: bool = True, + ) -> None: + super().__init__( + spatial_dims, + in_channels, + num_res_blocks, + num_channels, + attention_levels, + norm_num_groups, + norm_eps, + resblock_updown, + num_head_channels, + with_conditioning, + transformer_num_layers, + cross_attention_dim, + num_class_embeds, + upcast_attention, + use_flash_attention, + conditioning_embedding_in_channels, + conditioning_embedding_num_channels, + ) + self.use_checkpointing = use_checkpointing + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + ) -> tuple[Sequence[torch.Tensor], torch.Tensor]: + emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels) + h = self._apply_initial_convolution(x) + if self.use_checkpointing: + controlnet_cond = torch.utils.checkpoint.checkpoint( + self.controlnet_cond_embedding, controlnet_cond, use_reentrant=False + ) + else: + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + h += controlnet_cond + down_block_res_samples, h = self._apply_down_blocks(emb, context, h) + h = self._apply_mid_block(emb, context, h) + down_block_res_samples, mid_block_res_sample = self._apply_controlnet_blocks(h, down_block_res_samples) + # scaling + down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples] + mid_block_res_sample *= conditioning_scale + + return down_block_res_samples, mid_block_res_sample + + def _prepare_time_and_class_embedding(self, x, timesteps, class_labels): + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + return emb + + def _apply_initial_convolution(self, x): + # 3. initial convolution + h = self.conv_in(x) + return h + + def _apply_down_blocks(self, emb, context, h): + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + return down_block_res_samples, h + + def _apply_mid_block(self, emb, context, h): + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + return h + + def _apply_controlnet_blocks(self, h, down_block_res_samples): + # 6. Control net blocks + controlnet_down_block_res_samples = [] + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples.append(down_block_res_sample) + + mid_block_res_sample = self.controlnet_mid_block(h) + + return controlnet_down_block_res_samples, mid_block_res_sample diff --git a/monai/apps/generation/maisi/utils/morphological_ops.py b/monai/apps/generation/maisi/utils/morphological_ops.py new file mode 100644 index 0000000000..88eaca9dd7 --- /dev/null +++ b/monai/apps/generation/maisi/utils/morphological_ops.py @@ -0,0 +1,170 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence + +import torch +import torch.nn.functional as F +from torch import Tensor + +from monai.utils import ( + convert_data_type, + convert_to_dst_type, + ensure_tuple_rep, +) +from monai.config import NdarrayOrTensor +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type + +def erode(mask: NdarrayOrTensor, filter_size: int|Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor: + """ + Erode 2D/3D binary mask. + + Args: + mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray. + filter_size: erosion filter size, has to be odd numbers, default to be 3. + pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Usually use default value and not changed. + + Return: + eroded mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray. + + Example: + + .. code-block:: python + + # define a naive network + mask = torch.zeros(3,2,3,3,3) + mask[:,:,1,1,1] = 1.0 + filter_size = 3 + erode_result = morphological_ops.erode(mask,filter_size) # expect torch.zeros(3,2,3,3,3) + dilate_result = morphological_ops.dilate(mask,filter_size) # expect torch.ones(3,2,3,3,3) + + """ + mask_t, *_ = convert_data_type(mask, torch.Tensor) + res_mask_t = erode_t(mask_t, filter_size=filter_size, pad_value=pad_value) + res_mask: NdarrayOrTensor + res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask) + return res_mask + +def dilate(mask: NdarrayOrTensor, filter_size: int|Sequence[int] = 3, pad_value: float = 0.0) -> NdarrayOrTensor: + """ + Dilate 2D/3D binary mask. + + Args: + mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray. + filter_size: dilation filter size, has to be odd numbers, default to be 3. + pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Usually use default value and not changed. + + Return: + dilated mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray. + + Example: + + .. code-block:: python + + # define a naive network + mask = torch.zeros(3,2,3,3,3) + mask[:,:,1,1,1] = 1.0 + filter_size = 3 + erode_result = morphological_ops.erode(mask,filter_size) # expect torch.zeros(3,2,3,3,3) + dilate_result = morphological_ops.dilate(mask,filter_size) # expect torch.ones(3,2,3,3,3) + """ + mask_t, *_ = convert_data_type(mask, torch.Tensor) + res_mask_t = dilate_t(mask_t, filter_size=filter_size, pad_value=pad_value) + res_mask: NdarrayOrTensor + res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask) + return res_mask + +def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int|Sequence[int], pad_value: float) -> Tensor: + """ + Get morphological filter result for 2D/3D mask with data type as torch tensor. + + Args: + mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor. + filter_size: morphological filter size, has to be odd numbers. + pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. + + Return: + Morphological filter result mask, [N,C,M,N] or [N,C,M,N,P] torch tensor + """ + spatial_dims = len(mask_t.shape)-2 + if spatial_dims not in [2,3]: + raise ValueError(f"spatial_dims must be either 2 or 3, yet got spatial_dims={spatial_dims} for mask tensor with shape of {mask_t.shape}.") + + # Define the structuring element + filter_size = ensure_tuple_rep(filter_size, spatial_dims) + if any(size % 2 == 0 for size in filter_size): + raise ValueError(f"All dimensions in filter_size must be odd numbers, yet got {filter_size}.") + + filter_shape = [mask_t.shape[1],mask_t.shape[1]]+list(filter_size) + structuring_element = torch.ones(filter_shape).to( + mask_t.device + ) + + # Pad the input tensor to handle border pixels + # Calculate padding size + pad_size = [] + for size in filter_size: + pad_size.extend([size // 2, size // 2]) + + input_padded = F.pad( + mask_t.float(), + pad_size, + mode="constant", + value=pad_value, + ) + + # Apply filter operation + if spatial_dims == 2: + output = F.conv2d(input_padded, structuring_element, padding=0)/torch.sum(structuring_element[0,...]) + if spatial_dims == 3: + output = F.conv3d(input_padded, structuring_element, padding=0)/torch.sum(structuring_element[0,...]) + + return output + +def erode_t(mask_t: Tensor, filter_size: int|Sequence[int] = 3, pad_value: float = 1.0) -> Tensor: + """ + Erode 2D/3D binary mask with data type as torch tensor. + + Args: + mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor. + filter_size: erosion filter size, has to be odd numbers, default to be 3. + pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Usually use default value and not changed. + + Return: + eroded mask, [N,C,M,N] or [N,C,M,N,P] torch tensor + """ + + output = get_morphological_filter_result_t(mask_t, filter_size, pad_value) + + # Set output values based on the minimum value within the structuring element + output = torch.where(output == 1.0, 1.0, 0.0) + + return output + + +def dilate_t(mask_t: Tensor, filter_size: int|Sequence[int] = 3, pad_value: float = 0.0) -> Tensor: + """ + Dilate 2D/3D binary mask with data type as torch tensor. + + Args: + mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor. + filter_size: dilation filter size, has to be odd numbers, default to be 3. + pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Usually use default value and not changed. + + Return: + dilated mask, [N,C,M,N] or [N,C,M,N,P] torch tensor + """ + output = get_morphological_filter_result_t(mask_t, filter_size, pad_value) + + # Set output values based on the minimum value within the structuring element + output = torch.where(output > 0, 1.0, 0.0) + + return output diff --git a/tests/test_morphological_ops.py b/tests/test_morphological_ops.py new file mode 100644 index 0000000000..a5d59c6128 --- /dev/null +++ b/tests/test_morphological_ops.py @@ -0,0 +1,130 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +import torch +from tests.utils import TEST_NDARRAYS, assert_allclose + +from parameterized import parameterized + +from monai.apps.generation.maisi.utils import morphological_ops + +TESTS_SHAPE = [] +for p in TEST_NDARRAYS: + mask = torch.zeros(1,1,5,5,5) + filter_size = 3 + TESTS_SHAPE.append( + [ + {"mask": p(mask), "filter_size": filter_size}, + [1,1,5,5,5], + ] + ) + mask = torch.zeros(3,2,5,5,5) + filter_size = 5 + TESTS_SHAPE.append( + [ + {"mask": p(mask), "filter_size": filter_size}, + [3,2,5,5,5], + ] + ) + mask = torch.zeros(1,1,1,1,1) + filter_size = 5 + TESTS_SHAPE.append( + [ + {"mask": p(mask), "filter_size": filter_size}, + [1,1,1,1,1], + ] + ) + +TESTS_VALUE_T = [] +mask = torch.ones(3,2,3,3,3) +filter_size = 3 +TESTS_VALUE_T.append( + [ + {"mask": mask, "filter_size": filter_size, "pad_value":1.0}, + torch.ones(3,2,3,3,3) + ] +) +mask = torch.zeros(3,2,3,3,3) +filter_size = 3 +TESTS_VALUE_T.append( + [ + {"mask": mask, "filter_size": filter_size, "pad_value":0.0}, + torch.zeros(3,2,3,3,3) + ] +) + +TESTS_VALUE = [] +for p in TEST_NDARRAYS: + mask = torch.zeros(3,2,5,5,5) + filter_size = 3 + TESTS_VALUE.append( + [ + {"mask": p(mask), "filter_size": filter_size}, + p(torch.zeros(3,2,5,5,5)), + p(torch.zeros(3,2,5,5,5)), + ] + ) + mask = torch.ones(1,1,3,3,3) + filter_size = 3 + TESTS_VALUE.append( + [ + {"mask": p(mask), "filter_size": filter_size}, + p(torch.ones(1,1,3,3,3)), + p(torch.ones(1,1,3,3,3)), + ] + ) + mask = torch.ones(1,2,3,3,3) + filter_size = 3 + TESTS_VALUE.append( + [ + {"mask": p(mask), "filter_size": filter_size}, + p(torch.ones(1,2,3,3,3)), + p(torch.ones(1,2,3,3,3)), + ] + ) + mask = torch.zeros(3,2,3,3,3) + mask[:,:,1,1,1] = 1.0 + filter_size = 3 + TESTS_VALUE.append( + [ + {"mask": p(mask), "filter_size": filter_size}, + p(torch.zeros(3,2,3,3,3)), + p(torch.ones(3,2,3,3,3)), + ] + ) + + +class TestMorph(unittest.TestCase): + + @parameterized.expand(TESTS_SHAPE) + def test_shape(self, input_data, expected_result): + result1 = morphological_ops.erode(input_data["mask"],input_data["filter_size"]) + assert_allclose(result1.shape, expected_result, type_test=False, device_test=False, atol=0.0) + + @parameterized.expand(TESTS_VALUE_T) + def test_value_t(self, input_data, expected_result): + result1 = morphological_ops.get_morphological_filter_result_t(input_data["mask"],input_data["filter_size"],input_data["pad_value"]) + # result1 = morphological_ops.erode(input_data["mask"],input_data["filter_size"]) + # assert_allclose(result1, expected_erode_result, type_test=True, device_test=True, atol=0.0) + assert_allclose(result1, expected_result, type_test=False, device_test=False, atol=0.0) + + @parameterized.expand(TESTS_VALUE) + def test_value(self, input_data, expected_erode_result, expected_dilate_result): + result1 = morphological_ops.erode(input_data["mask"],input_data["filter_size"]) + assert_allclose(result1, expected_erode_result, type_test=True, device_test=True, atol=0.0) + result2 = morphological_ops.dilate(input_data["mask"],input_data["filter_size"]) + assert_allclose(result2, expected_dilate_result, type_test=True, device_test=True, atol=0.0) + +if __name__ == "__main__": + unittest.main() From 7a8a0ef02b52e5cbdcfd31bf60738bc083cdd7fa Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Mon, 1 Jul 2024 16:03:09 +0000 Subject: [PATCH 02/19] black Signed-off-by: Can-Zhao --- .../maisi/utils/morphological_ops.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/monai/apps/generation/maisi/utils/morphological_ops.py b/monai/apps/generation/maisi/utils/morphological_ops.py index 88eaca9dd7..9436f5c1b6 100644 --- a/monai/apps/generation/maisi/utils/morphological_ops.py +++ b/monai/apps/generation/maisi/utils/morphological_ops.py @@ -23,7 +23,8 @@ from monai.config import NdarrayOrTensor from monai.utils.type_conversion import convert_data_type, convert_to_dst_type -def erode(mask: NdarrayOrTensor, filter_size: int|Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor: + +def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor: """ Erode 2D/3D binary mask. @@ -53,7 +54,8 @@ def erode(mask: NdarrayOrTensor, filter_size: int|Sequence[int] = 3, pad_value: res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask) return res_mask -def dilate(mask: NdarrayOrTensor, filter_size: int|Sequence[int] = 3, pad_value: float = 0.0) -> NdarrayOrTensor: + +def dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> NdarrayOrTensor: """ Dilate 2D/3D binary mask. @@ -82,7 +84,8 @@ def dilate(mask: NdarrayOrTensor, filter_size: int|Sequence[int] = 3, pad_value: res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask) return res_mask -def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int|Sequence[int], pad_value: float) -> Tensor: + +def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequence[int], pad_value: float) -> Tensor: """ Get morphological filter result for 2D/3D mask with data type as torch tensor. @@ -94,19 +97,19 @@ def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int|Sequence[ Return: Morphological filter result mask, [N,C,M,N] or [N,C,M,N,P] torch tensor """ - spatial_dims = len(mask_t.shape)-2 - if spatial_dims not in [2,3]: - raise ValueError(f"spatial_dims must be either 2 or 3, yet got spatial_dims={spatial_dims} for mask tensor with shape of {mask_t.shape}.") + spatial_dims = len(mask_t.shape) - 2 + if spatial_dims not in [2, 3]: + raise ValueError( + f"spatial_dims must be either 2 or 3, yet got spatial_dims={spatial_dims} for mask tensor with shape of {mask_t.shape}." + ) # Define the structuring element filter_size = ensure_tuple_rep(filter_size, spatial_dims) if any(size % 2 == 0 for size in filter_size): raise ValueError(f"All dimensions in filter_size must be odd numbers, yet got {filter_size}.") - filter_shape = [mask_t.shape[1],mask_t.shape[1]]+list(filter_size) - structuring_element = torch.ones(filter_shape).to( - mask_t.device - ) + filter_shape = [mask_t.shape[1], mask_t.shape[1]] + list(filter_size) + structuring_element = torch.ones(filter_shape).to(mask_t.device) # Pad the input tensor to handle border pixels # Calculate padding size @@ -123,13 +126,14 @@ def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int|Sequence[ # Apply filter operation if spatial_dims == 2: - output = F.conv2d(input_padded, structuring_element, padding=0)/torch.sum(structuring_element[0,...]) + output = F.conv2d(input_padded, structuring_element, padding=0) / torch.sum(structuring_element[0, ...]) if spatial_dims == 3: - output = F.conv3d(input_padded, structuring_element, padding=0)/torch.sum(structuring_element[0,...]) + output = F.conv3d(input_padded, structuring_element, padding=0) / torch.sum(structuring_element[0, ...]) return output -def erode_t(mask_t: Tensor, filter_size: int|Sequence[int] = 3, pad_value: float = 1.0) -> Tensor: + +def erode_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> Tensor: """ Erode 2D/3D binary mask with data type as torch tensor. @@ -150,7 +154,7 @@ def erode_t(mask_t: Tensor, filter_size: int|Sequence[int] = 3, pad_value: float return output -def dilate_t(mask_t: Tensor, filter_size: int|Sequence[int] = 3, pad_value: float = 0.0) -> Tensor: +def dilate_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> Tensor: """ Dilate 2D/3D binary mask with data type as torch tensor. From 017d8cc6a9f277e3e959b4656c7bdff742f6c522 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Mon, 1 Jul 2024 16:06:23 +0000 Subject: [PATCH 03/19] reformat Signed-off-by: Can-Zhao --- monai/apps/generation/maisi/utils/morphological_ops.py | 8 +++----- tests/test_morphological_ops.py | 4 ++-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/monai/apps/generation/maisi/utils/morphological_ops.py b/monai/apps/generation/maisi/utils/morphological_ops.py index 9436f5c1b6..011e0e6644 100644 --- a/monai/apps/generation/maisi/utils/morphological_ops.py +++ b/monai/apps/generation/maisi/utils/morphological_ops.py @@ -9,18 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import Sequence import torch import torch.nn.functional as F from torch import Tensor -from monai.utils import ( - convert_data_type, - convert_to_dst_type, - ensure_tuple_rep, -) from monai.config import NdarrayOrTensor +from monai.utils import convert_data_type, convert_to_dst_type, ensure_tuple_rep from monai.utils.type_conversion import convert_data_type, convert_to_dst_type diff --git a/tests/test_morphological_ops.py b/tests/test_morphological_ops.py index a5d59c6128..a425859713 100644 --- a/tests/test_morphological_ops.py +++ b/tests/test_morphological_ops.py @@ -12,12 +12,12 @@ from __future__ import annotations import unittest -import torch -from tests.utils import TEST_NDARRAYS, assert_allclose +import torch from parameterized import parameterized from monai.apps.generation.maisi.utils import morphological_ops +from tests.utils import TEST_NDARRAYS, assert_allclose TESTS_SHAPE = [] for p in TEST_NDARRAYS: From 7c22897f332fa956767268006ca155b58c631eab Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Mon, 1 Jul 2024 16:22:38 +0000 Subject: [PATCH 04/19] reformat Signed-off-by: Can-Zhao --- .../maisi/utils/morphological_ops.py | 8 +- tests/test_morphological_ops.py | 90 ++++++------------- 2 files changed, 27 insertions(+), 71 deletions(-) diff --git a/monai/apps/generation/maisi/utils/morphological_ops.py b/monai/apps/generation/maisi/utils/morphological_ops.py index 011e0e6644..92dd7f64b3 100644 --- a/monai/apps/generation/maisi/utils/morphological_ops.py +++ b/monai/apps/generation/maisi/utils/morphological_ops.py @@ -19,7 +19,6 @@ from monai.config import NdarrayOrTensor from monai.utils import convert_data_type, convert_to_dst_type, ensure_tuple_rep -from monai.utils.type_conversion import convert_data_type, convert_to_dst_type def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor: @@ -115,12 +114,7 @@ def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequenc for size in filter_size: pad_size.extend([size // 2, size // 2]) - input_padded = F.pad( - mask_t.float(), - pad_size, - mode="constant", - value=pad_value, - ) + input_padded = F.pad(mask_t.float(), pad_size, mode="constant", value=pad_value) # Apply filter operation if spatial_dims == 2: diff --git a/tests/test_morphological_ops.py b/tests/test_morphological_ops.py index a425859713..30038c6caa 100644 --- a/tests/test_morphological_ops.py +++ b/tests/test_morphological_ops.py @@ -21,87 +21,46 @@ TESTS_SHAPE = [] for p in TEST_NDARRAYS: - mask = torch.zeros(1,1,5,5,5) + mask = torch.zeros(1, 1, 5, 5, 5) filter_size = 3 - TESTS_SHAPE.append( - [ - {"mask": p(mask), "filter_size": filter_size}, - [1,1,5,5,5], - ] - ) - mask = torch.zeros(3,2,5,5,5) + TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 5, 5, 5]]) + mask = torch.zeros(3, 2, 5, 5, 5) filter_size = 5 - TESTS_SHAPE.append( - [ - {"mask": p(mask), "filter_size": filter_size}, - [3,2,5,5,5], - ] - ) - mask = torch.zeros(1,1,1,1,1) + TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [3, 2, 5, 5, 5]]) + mask = torch.zeros(1, 1, 1, 1, 1) filter_size = 5 - TESTS_SHAPE.append( - [ - {"mask": p(mask), "filter_size": filter_size}, - [1,1,1,1,1], - ] - ) + TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 1, 1, 1]]) TESTS_VALUE_T = [] -mask = torch.ones(3,2,3,3,3) +mask = torch.ones(3, 2, 3, 3, 3) filter_size = 3 -TESTS_VALUE_T.append( - [ - {"mask": mask, "filter_size": filter_size, "pad_value":1.0}, - torch.ones(3,2,3,3,3) - ] -) -mask = torch.zeros(3,2,3,3,3) +TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 1.0}, torch.ones(3, 2, 3, 3, 3)]) +mask = torch.zeros(3, 2, 3, 3, 3) filter_size = 3 -TESTS_VALUE_T.append( - [ - {"mask": mask, "filter_size": filter_size, "pad_value":0.0}, - torch.zeros(3,2,3,3,3) - ] -) +TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 0.0}, torch.zeros(3, 2, 3, 3, 3)]) TESTS_VALUE = [] for p in TEST_NDARRAYS: - mask = torch.zeros(3,2,5,5,5) + mask = torch.zeros(3, 2, 5, 5, 5) filter_size = 3 TESTS_VALUE.append( - [ - {"mask": p(mask), "filter_size": filter_size}, - p(torch.zeros(3,2,5,5,5)), - p(torch.zeros(3,2,5,5,5)), - ] + [{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 5, 5, 5)), p(torch.zeros(3, 2, 5, 5, 5))] ) - mask = torch.ones(1,1,3,3,3) + mask = torch.ones(1, 1, 3, 3, 3) filter_size = 3 TESTS_VALUE.append( - [ - {"mask": p(mask), "filter_size": filter_size}, - p(torch.ones(1,1,3,3,3)), - p(torch.ones(1,1,3,3,3)), - ] + [{"mask": p(mask), "filter_size": filter_size}, p(torch.ones(1, 1, 3, 3, 3)), p(torch.ones(1, 1, 3, 3, 3))] ) - mask = torch.ones(1,2,3,3,3) + mask = torch.ones(1, 2, 3, 3, 3) filter_size = 3 TESTS_VALUE.append( - [ - {"mask": p(mask), "filter_size": filter_size}, - p(torch.ones(1,2,3,3,3)), - p(torch.ones(1,2,3,3,3)), - ] + [{"mask": p(mask), "filter_size": filter_size}, p(torch.ones(1, 2, 3, 3, 3)), p(torch.ones(1, 2, 3, 3, 3))] ) - mask = torch.zeros(3,2,3,3,3) - mask[:,:,1,1,1] = 1.0 + mask = torch.zeros(3, 2, 3, 3, 3) + mask[:, :, 1, 1, 1] = 1.0 filter_size = 3 TESTS_VALUE.append( - [ - {"mask": p(mask), "filter_size": filter_size}, - p(torch.zeros(3,2,3,3,3)), - p(torch.ones(3,2,3,3,3)), - ] + [{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 3, 3, 3)), p(torch.ones(3, 2, 3, 3, 3))] ) @@ -109,22 +68,25 @@ class TestMorph(unittest.TestCase): @parameterized.expand(TESTS_SHAPE) def test_shape(self, input_data, expected_result): - result1 = morphological_ops.erode(input_data["mask"],input_data["filter_size"]) + result1 = morphological_ops.erode(input_data["mask"], input_data["filter_size"]) assert_allclose(result1.shape, expected_result, type_test=False, device_test=False, atol=0.0) @parameterized.expand(TESTS_VALUE_T) def test_value_t(self, input_data, expected_result): - result1 = morphological_ops.get_morphological_filter_result_t(input_data["mask"],input_data["filter_size"],input_data["pad_value"]) + result1 = morphological_ops.get_morphological_filter_result_t( + input_data["mask"], input_data["filter_size"], input_data["pad_value"] + ) # result1 = morphological_ops.erode(input_data["mask"],input_data["filter_size"]) # assert_allclose(result1, expected_erode_result, type_test=True, device_test=True, atol=0.0) assert_allclose(result1, expected_result, type_test=False, device_test=False, atol=0.0) @parameterized.expand(TESTS_VALUE) def test_value(self, input_data, expected_erode_result, expected_dilate_result): - result1 = morphological_ops.erode(input_data["mask"],input_data["filter_size"]) + result1 = morphological_ops.erode(input_data["mask"], input_data["filter_size"]) assert_allclose(result1, expected_erode_result, type_test=True, device_test=True, atol=0.0) - result2 = morphological_ops.dilate(input_data["mask"],input_data["filter_size"]) + result2 = morphological_ops.dilate(input_data["mask"], input_data["filter_size"]) assert_allclose(result2, expected_dilate_result, type_test=True, device_test=True, atol=0.0) + if __name__ == "__main__": unittest.main() From 5899029ded83f58ce92155414a6d5e6e435852f2 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Mon, 1 Jul 2024 18:23:27 +0000 Subject: [PATCH 05/19] typo Signed-off-by: Can-Zhao --- monai/apps/generation/maisi/utils/morphological_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/utils/morphological_ops.py b/monai/apps/generation/maisi/utils/morphological_ops.py index 92dd7f64b3..4c6daee6a1 100644 --- a/monai/apps/generation/maisi/utils/morphological_ops.py +++ b/monai/apps/generation/maisi/utils/morphological_ops.py @@ -37,7 +37,7 @@ def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value .. code-block:: python - # define a naive network + # define a naive mask mask = torch.zeros(3,2,3,3,3) mask[:,:,1,1,1] = 1.0 filter_size = 3 @@ -68,7 +68,7 @@ def dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_valu .. code-block:: python - # define a naive network + # define a naive mask mask = torch.zeros(3,2,3,3,3) mask[:,:,1,1,1] = 1.0 filter_size = 3 From abb26735141106233afee31f2d9286e90a1b6392 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Mon, 1 Jul 2024 18:24:29 +0000 Subject: [PATCH 06/19] update docstring Signed-off-by: Can-Zhao --- monai/apps/generation/maisi/utils/morphological_ops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/apps/generation/maisi/utils/morphological_ops.py b/monai/apps/generation/maisi/utils/morphological_ops.py index 4c6daee6a1..f9b3678f57 100644 --- a/monai/apps/generation/maisi/utils/morphological_ops.py +++ b/monai/apps/generation/maisi/utils/morphological_ops.py @@ -41,8 +41,8 @@ def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value mask = torch.zeros(3,2,3,3,3) mask[:,:,1,1,1] = 1.0 filter_size = 3 - erode_result = morphological_ops.erode(mask,filter_size) # expect torch.zeros(3,2,3,3,3) - dilate_result = morphological_ops.dilate(mask,filter_size) # expect torch.ones(3,2,3,3,3) + erode_result = erode(mask,filter_size) # expect torch.zeros(3,2,3,3,3) + dilate_result = dilate(mask,filter_size) # expect torch.ones(3,2,3,3,3) """ mask_t, *_ = convert_data_type(mask, torch.Tensor) @@ -72,8 +72,8 @@ def dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_valu mask = torch.zeros(3,2,3,3,3) mask[:,:,1,1,1] = 1.0 filter_size = 3 - erode_result = morphological_ops.erode(mask,filter_size) # expect torch.zeros(3,2,3,3,3) - dilate_result = morphological_ops.dilate(mask,filter_size) # expect torch.ones(3,2,3,3,3) + erode_result = erode(mask,filter_size) # expect torch.zeros(3,2,3,3,3) + dilate_result = dilate(mask,filter_size) # expect torch.ones(3,2,3,3,3) """ mask_t, *_ = convert_data_type(mask, torch.Tensor) res_mask_t = dilate_t(mask_t, filter_size=filter_size, pad_value=pad_value) From f767af070c50e932e590e6bc1f456569b6eda2f8 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 2 Jul 2024 04:36:31 +0000 Subject: [PATCH 07/19] add eps Signed-off-by: Can-Zhao --- monai/apps/generation/maisi/utils/morphological_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/generation/maisi/utils/morphological_ops.py b/monai/apps/generation/maisi/utils/morphological_ops.py index f9b3678f57..b31064b9b5 100644 --- a/monai/apps/generation/maisi/utils/morphological_ops.py +++ b/monai/apps/generation/maisi/utils/morphological_ops.py @@ -141,7 +141,7 @@ def erode_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: flo output = get_morphological_filter_result_t(mask_t, filter_size, pad_value) # Set output values based on the minimum value within the structuring element - output = torch.where(output == 1.0, 1.0, 0.0) + output = torch.where(abs(output - 1.0) < 1e-7, 1.0, 0.0) return output From 2e61c44802bc473dc8e5d695c66b37f6b3181a34 Mon Sep 17 00:00:00 2001 From: Can Zhao <69829124+Can-Zhao@users.noreply.github.com> Date: Mon, 1 Jul 2024 21:50:08 -0700 Subject: [PATCH 08/19] Update monai/apps/generation/maisi/utils/morphological_ops.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Can Zhao <69829124+Can-Zhao@users.noreply.github.com> --- monai/apps/generation/maisi/utils/morphological_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/generation/maisi/utils/morphological_ops.py b/monai/apps/generation/maisi/utils/morphological_ops.py index f9b3678f57..d322e7f93a 100644 --- a/monai/apps/generation/maisi/utils/morphological_ops.py +++ b/monai/apps/generation/maisi/utils/morphological_ops.py @@ -84,7 +84,7 @@ def dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_valu def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequence[int], pad_value: float) -> Tensor: """ - Get morphological filter result for 2D/3D mask with data type as torch tensor. + Apply a morphological filter to a 2D/3D binary mask tensor. Args: mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor. From b8879108ed234fe9cec1949d03255579e30fbc15 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 2 Jul 2024 04:51:10 +0000 Subject: [PATCH 09/19] update docstring Signed-off-by: Can-Zhao --- .../maisi/utils/morphological_ops.py | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/monai/apps/generation/maisi/utils/morphological_ops.py b/monai/apps/generation/maisi/utils/morphological_ops.py index b31064b9b5..57d0770dba 100644 --- a/monai/apps/generation/maisi/utils/morphological_ops.py +++ b/monai/apps/generation/maisi/utils/morphological_ops.py @@ -31,7 +31,7 @@ def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Usually use default value and not changed. Return: - eroded mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray. + eroded mask, same shape and data type as input. Example: @@ -62,7 +62,7 @@ def dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_valu pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Usually use default value and not changed. Return: - dilated mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray. + dilated mask, same shape and data type as input. Example: @@ -84,7 +84,7 @@ def dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_valu def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequence[int], pad_value: float) -> Tensor: """ - Get morphological filter result for 2D/3D mask with data type as torch tensor. + Apply a morphological filter to a 2D/3D binary mask tensor. Args: mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor. @@ -92,21 +92,21 @@ def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequenc pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Return: - Morphological filter result mask, [N,C,M,N] or [N,C,M,N,P] torch tensor + Tensor: Morphological filter result mask, same shape as input. """ spatial_dims = len(mask_t.shape) - 2 if spatial_dims not in [2, 3]: raise ValueError( - f"spatial_dims must be either 2 or 3, yet got spatial_dims={spatial_dims} for mask tensor with shape of {mask_t.shape}." + f"spatial_dims must be either 2 or 3, got spatial_dims={spatial_dims} for mask tensor with shape of { + mask_t.shape}." ) # Define the structuring element filter_size = ensure_tuple_rep(filter_size, spatial_dims) if any(size % 2 == 0 for size in filter_size): - raise ValueError(f"All dimensions in filter_size must be odd numbers, yet got {filter_size}.") + raise ValueError(f"All dimensions in filter_size must be odd numbers, got {filter_size}.") - filter_shape = [mask_t.shape[1], mask_t.shape[1]] + list(filter_size) - structuring_element = torch.ones(filter_shape).to(mask_t.device) + structuring_element = torch.ones((mask_t.shape[1], mask_t.shape[1]) + filter_size).to(mask_t.device) # Pad the input tensor to handle border pixels # Calculate padding size @@ -117,10 +117,8 @@ def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequenc input_padded = F.pad(mask_t.float(), pad_size, mode="constant", value=pad_value) # Apply filter operation - if spatial_dims == 2: - output = F.conv2d(input_padded, structuring_element, padding=0) / torch.sum(structuring_element[0, ...]) - if spatial_dims == 3: - output = F.conv3d(input_padded, structuring_element, padding=0) / torch.sum(structuring_element[0, ...]) + conv_fn = F.conv2d if spatial_dims == 2 else F.conv3d + output = conv_fn(input_padded, structuring_element, padding=0) / torch.sum(structuring_element[0, ...]) return output @@ -135,13 +133,13 @@ def erode_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: flo pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Usually use default value and not changed. Return: - eroded mask, [N,C,M,N] or [N,C,M,N,P] torch tensor + Tensor: eroded mask, same shape as input. """ output = get_morphological_filter_result_t(mask_t, filter_size, pad_value) # Set output values based on the minimum value within the structuring element - output = torch.where(abs(output - 1.0) < 1e-7, 1.0, 0.0) + output = torch.where(torch.abs(output - 1.0) < 1e-7, 1.0, 0.0) return output @@ -156,7 +154,7 @@ def dilate_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: fl pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Usually use default value and not changed. Return: - dilated mask, [N,C,M,N] or [N,C,M,N,P] torch tensor + Tensor: dilated mask, same shape as input. """ output = get_morphological_filter_result_t(mask_t, filter_size, pad_value) From 365e48d0c960890d904f1c93b91300e60b7b747a Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 2 Jul 2024 04:55:19 +0000 Subject: [PATCH 10/19] add 2d unit test Signed-off-by: Can-Zhao --- tests/test_morphological_ops.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/test_morphological_ops.py b/tests/test_morphological_ops.py index 30038c6caa..ba9b440a53 100644 --- a/tests/test_morphological_ops.py +++ b/tests/test_morphological_ops.py @@ -30,14 +30,19 @@ mask = torch.zeros(1, 1, 1, 1, 1) filter_size = 5 TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 1, 1, 1]]) + mask = torch.zeros(1, 1, 1, 1) + filter_size = 5 + TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 1, 1]]) TESTS_VALUE_T = [] mask = torch.ones(3, 2, 3, 3, 3) filter_size = 3 TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 1.0}, torch.ones(3, 2, 3, 3, 3)]) -mask = torch.zeros(3, 2, 3, 3, 3) -filter_size = 3 TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 0.0}, torch.zeros(3, 2, 3, 3, 3)]) +mask = torch.ones(3, 2, 3, 3) +filter_size = 3 +TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 1.0}, torch.ones(3, 2, 3, 3)]) +TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 0.0}, torch.zeros(3, 2, 3, 3)]) TESTS_VALUE = [] for p in TEST_NDARRAYS: @@ -62,6 +67,12 @@ TESTS_VALUE.append( [{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 3, 3, 3)), p(torch.ones(3, 2, 3, 3, 3))] ) + mask = torch.zeros(3, 2, 3, 3) + mask[:, :, 1, 1, 1] = 1.0 + filter_size = 3 + TESTS_VALUE.append( + [{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 3, 3)), p(torch.ones(3, 2, 3, 3))] + ) class TestMorph(unittest.TestCase): From 971445e09391e40450140d7f43a64526034ddb7b Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 2 Jul 2024 05:03:50 +0000 Subject: [PATCH 11/19] add 2d unit test Signed-off-by: Can-Zhao --- monai/apps/generation/maisi/utils/morphological_ops.py | 4 +--- tests/test_morphological_ops.py | 7 ++++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/monai/apps/generation/maisi/utils/morphological_ops.py b/monai/apps/generation/maisi/utils/morphological_ops.py index 57d0770dba..5ac4add3c2 100644 --- a/monai/apps/generation/maisi/utils/morphological_ops.py +++ b/monai/apps/generation/maisi/utils/morphological_ops.py @@ -110,9 +110,7 @@ def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequenc # Pad the input tensor to handle border pixels # Calculate padding size - pad_size = [] - for size in filter_size: - pad_size.extend([size // 2, size // 2]) + pad_size = [size // 2 for size in filter_size for _ in range(2)] input_padded = F.pad(mask_t.float(), pad_size, mode="constant", value=pad_value) diff --git a/tests/test_morphological_ops.py b/tests/test_morphological_ops.py index ba9b440a53..43ae1eb8cc 100644 --- a/tests/test_morphological_ops.py +++ b/tests/test_morphological_ops.py @@ -35,13 +35,14 @@ TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 1, 1]]) TESTS_VALUE_T = [] -mask = torch.ones(3, 2, 3, 3, 3) filter_size = 3 +mask = torch.ones(3, 2, 3, 3, 3) TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 1.0}, torch.ones(3, 2, 3, 3, 3)]) +mask = torch.zeros(3, 2, 3, 3, 3) TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 0.0}, torch.zeros(3, 2, 3, 3, 3)]) mask = torch.ones(3, 2, 3, 3) -filter_size = 3 TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 1.0}, torch.ones(3, 2, 3, 3)]) +mask = torch.zeros(3, 2, 3, 3) TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 0.0}, torch.zeros(3, 2, 3, 3)]) TESTS_VALUE = [] @@ -68,7 +69,7 @@ [{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 3, 3, 3)), p(torch.ones(3, 2, 3, 3, 3))] ) mask = torch.zeros(3, 2, 3, 3) - mask[:, :, 1, 1, 1] = 1.0 + mask[:, :, 1, 1] = 1.0 filter_size = 3 TESTS_VALUE.append( [{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 3, 3)), p(torch.ones(3, 2, 3, 3))] From 0e36997e2f6f08a820e7b461852ea49f36ac372f Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 2 Jul 2024 05:10:48 +0000 Subject: [PATCH 12/19] shorten docstring line Signed-off-by: Can-Zhao --- .../maisi/utils/morphological_ops.py | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/monai/apps/generation/maisi/utils/morphological_ops.py b/monai/apps/generation/maisi/utils/morphological_ops.py index 5ac4add3c2..2df7994e70 100644 --- a/monai/apps/generation/maisi/utils/morphological_ops.py +++ b/monai/apps/generation/maisi/utils/morphological_ops.py @@ -28,7 +28,9 @@ def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value Args: mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray. filter_size: erosion filter size, has to be odd numbers, default to be 3. - pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Usually use default value and not changed. + pad_value: the filled value for padding. We need to pad the input before filtering + to keep the output with the same size as input. Usually use default value + and not changed. Return: eroded mask, same shape and data type as input. @@ -41,9 +43,8 @@ def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value mask = torch.zeros(3,2,3,3,3) mask[:,:,1,1,1] = 1.0 filter_size = 3 - erode_result = erode(mask,filter_size) # expect torch.zeros(3,2,3,3,3) - dilate_result = dilate(mask,filter_size) # expect torch.ones(3,2,3,3,3) - + erode_result = erode(mask, filter_size) # expect torch.zeros(3,2,3,3,3) + dilate_result = dilate(mask, filter_size) # expect torch.ones(3,2,3,3,3) """ mask_t, *_ = convert_data_type(mask, torch.Tensor) res_mask_t = erode_t(mask_t, filter_size=filter_size, pad_value=pad_value) @@ -52,6 +53,7 @@ def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value return res_mask + def dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> NdarrayOrTensor: """ Dilate 2D/3D binary mask. @@ -59,7 +61,9 @@ def dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_valu Args: mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray. filter_size: dilation filter size, has to be odd numbers, default to be 3. - pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Usually use default value and not changed. + pad_value: the filled value for padding. We need to pad the input before filtering + to keep the output with the same size as input. Usually use default value + and not changed. Return: dilated mask, same shape and data type as input. @@ -89,7 +93,8 @@ def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequenc Args: mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor. filter_size: morphological filter size, has to be odd numbers. - pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. + pad_value: the filled value for padding. We need to pad the input before filtering + to keep the output with the same size as input. Return: Tensor: Morphological filter result mask, same shape as input. @@ -97,8 +102,7 @@ def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequenc spatial_dims = len(mask_t.shape) - 2 if spatial_dims not in [2, 3]: raise ValueError( - f"spatial_dims must be either 2 or 3, got spatial_dims={spatial_dims} for mask tensor with shape of { - mask_t.shape}." + f"spatial_dims must be either 2 or 3, got spatial_dims={spatial_dims} for mask tensor with shape of {mask_t.shape}." ) # Define the structuring element @@ -128,7 +132,9 @@ def erode_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: flo Args: mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor. filter_size: erosion filter size, has to be odd numbers, default to be 3. - pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Usually use default value and not changed. + pad_value: the filled value for padding. We need to pad the input before filtering + to keep the output with the same size as input. Usually use default value + and not changed. Return: Tensor: eroded mask, same shape as input. @@ -149,7 +155,9 @@ def dilate_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: fl Args: mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor. filter_size: dilation filter size, has to be odd numbers, default to be 3. - pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Usually use default value and not changed. + pad_value: the filled value for padding. We need to pad the input before filtering + to keep the output with the same size as input. Usually use default value + and not changed. Return: Tensor: dilated mask, same shape as input. From 203f4235554377438000538969e3bd99909e1209 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Jul 2024 05:11:15 +0000 Subject: [PATCH 13/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../maisi/utils/morphological_ops.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/monai/apps/generation/maisi/utils/morphological_ops.py b/monai/apps/generation/maisi/utils/morphological_ops.py index 2df7994e70..be8c27d97c 100644 --- a/monai/apps/generation/maisi/utils/morphological_ops.py +++ b/monai/apps/generation/maisi/utils/morphological_ops.py @@ -28,8 +28,8 @@ def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value Args: mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray. filter_size: erosion filter size, has to be odd numbers, default to be 3. - pad_value: the filled value for padding. We need to pad the input before filtering - to keep the output with the same size as input. Usually use default value + pad_value: the filled value for padding. We need to pad the input before filtering + to keep the output with the same size as input. Usually use default value and not changed. Return: @@ -61,8 +61,8 @@ def dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_valu Args: mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray. filter_size: dilation filter size, has to be odd numbers, default to be 3. - pad_value: the filled value for padding. We need to pad the input before filtering - to keep the output with the same size as input. Usually use default value + pad_value: the filled value for padding. We need to pad the input before filtering + to keep the output with the same size as input. Usually use default value and not changed. Return: @@ -93,7 +93,7 @@ def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequenc Args: mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor. filter_size: morphological filter size, has to be odd numbers. - pad_value: the filled value for padding. We need to pad the input before filtering + pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Return: @@ -132,8 +132,8 @@ def erode_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: flo Args: mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor. filter_size: erosion filter size, has to be odd numbers, default to be 3. - pad_value: the filled value for padding. We need to pad the input before filtering - to keep the output with the same size as input. Usually use default value + pad_value: the filled value for padding. We need to pad the input before filtering + to keep the output with the same size as input. Usually use default value and not changed. Return: @@ -155,8 +155,8 @@ def dilate_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: fl Args: mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor. filter_size: dilation filter size, has to be odd numbers, default to be 3. - pad_value: the filled value for padding. We need to pad the input before filtering - to keep the output with the same size as input. Usually use default value + pad_value: the filled value for padding. We need to pad the input before filtering + to keep the output with the same size as input. Usually use default value and not changed. Return: From 44b990cb5219abfa3d69d31c2aab81dd61088c55 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 2 Jul 2024 05:21:16 +0000 Subject: [PATCH 14/19] reformat Signed-off-by: Can-Zhao --- monai/apps/generation/maisi/utils/morphological_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/utils/morphological_ops.py b/monai/apps/generation/maisi/utils/morphological_ops.py index be8c27d97c..14786d60a2 100644 --- a/monai/apps/generation/maisi/utils/morphological_ops.py +++ b/monai/apps/generation/maisi/utils/morphological_ops.py @@ -53,7 +53,6 @@ def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value return res_mask - def dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> NdarrayOrTensor: """ Dilate 2D/3D binary mask. @@ -102,7 +101,8 @@ def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequenc spatial_dims = len(mask_t.shape) - 2 if spatial_dims not in [2, 3]: raise ValueError( - f"spatial_dims must be either 2 or 3, got spatial_dims={spatial_dims} for mask tensor with shape of {mask_t.shape}." + f"spatial_dims must be either 2 or 3, " + f"got spatial_dims={spatial_dims} for mask tensor with shape of {mask_t.shape}." ) # Define the structuring element From 450c4864350f8fb7bad01e5735c1adb434b1f5e4 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 2 Jul 2024 05:25:25 +0000 Subject: [PATCH 15/19] update doc Signed-off-by: Can-Zhao --- docs/source/apps.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 7fa7b9e9ff..6e7a83893c 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -261,3 +261,11 @@ FastMRIReader .. autoclass:: monai.apps.nnunet.nnUNetV2Runner :members: + +`Generative AI` +----------- + +`MAISI Utilities` +~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: monai.apps.generation/maisi/utils/morphological_ops + :members: From ef7c5496476f989172db1e1051188981126c5b7a Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 2 Jul 2024 05:29:37 +0000 Subject: [PATCH 16/19] add doc Signed-off-by: Can-Zhao --- docs/source/apps.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 6e7a83893c..d74a8a54f3 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -267,5 +267,5 @@ FastMRIReader `MAISI Utilities` ~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: monai.apps.generation/maisi/utils/morphological_ops +.. automodule:: monai.apps.generation.maisi.utils.morphological_ops :members: From 263a36872ea3fe172860568c1c7e4ec822e476b0 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 2 Jul 2024 05:35:15 +0000 Subject: [PATCH 17/19] update doc Signed-off-by: Can-Zhao --- docs/source/apps.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index d74a8a54f3..c6ba8c0b9a 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -263,9 +263,9 @@ FastMRIReader :members: `Generative AI` ------------ +--------------- `MAISI Utilities` -~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~ .. automodule:: monai.apps.generation.maisi.utils.morphological_ops :members: From c090186d8e3c5c9bb63107f4fc2e71252d287dbc Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 2 Jul 2024 06:12:34 +0000 Subject: [PATCH 18/19] update test import Signed-off-by: Can-Zhao --- tests/test_morphological_ops.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/test_morphological_ops.py b/tests/test_morphological_ops.py index 43ae1eb8cc..6f29415759 100644 --- a/tests/test_morphological_ops.py +++ b/tests/test_morphological_ops.py @@ -16,7 +16,7 @@ import torch from parameterized import parameterized -from monai.apps.generation.maisi.utils import morphological_ops +from monai.apps.generation.maisi.utils.morphological_ops import dilate, erode, get_morphological_filter_result_t from tests.utils import TEST_NDARRAYS, assert_allclose TESTS_SHAPE = [] @@ -80,23 +80,21 @@ class TestMorph(unittest.TestCase): @parameterized.expand(TESTS_SHAPE) def test_shape(self, input_data, expected_result): - result1 = morphological_ops.erode(input_data["mask"], input_data["filter_size"]) + result1 = erode(input_data["mask"], input_data["filter_size"]) assert_allclose(result1.shape, expected_result, type_test=False, device_test=False, atol=0.0) @parameterized.expand(TESTS_VALUE_T) def test_value_t(self, input_data, expected_result): - result1 = morphological_ops.get_morphological_filter_result_t( + result1 = get_morphological_filter_result_t( input_data["mask"], input_data["filter_size"], input_data["pad_value"] ) - # result1 = morphological_ops.erode(input_data["mask"],input_data["filter_size"]) - # assert_allclose(result1, expected_erode_result, type_test=True, device_test=True, atol=0.0) assert_allclose(result1, expected_result, type_test=False, device_test=False, atol=0.0) @parameterized.expand(TESTS_VALUE) def test_value(self, input_data, expected_erode_result, expected_dilate_result): - result1 = morphological_ops.erode(input_data["mask"], input_data["filter_size"]) + result1 = erode(input_data["mask"], input_data["filter_size"]) assert_allclose(result1, expected_erode_result, type_test=True, device_test=True, atol=0.0) - result2 = morphological_ops.dilate(input_data["mask"], input_data["filter_size"]) + result2 = dilate(input_data["mask"], input_data["filter_size"]) assert_allclose(result2, expected_dilate_result, type_test=True, device_test=True, atol=0.0) From 451e5dafd735460e10af92d1fed9f7e2d9253f3e Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 2 Jul 2024 06:27:58 +0000 Subject: [PATCH 19/19] add init Signed-off-by: Can-Zhao --- monai/apps/generation/maisi/utils/__init__.py | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 monai/apps/generation/maisi/utils/__init__.py diff --git a/monai/apps/generation/maisi/utils/__init__.py b/monai/apps/generation/maisi/utils/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/apps/generation/maisi/utils/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.