From 1e43d0d5ea2153346880704fae92668c63e1c0d0 Mon Sep 17 00:00:00 2001
From: dongyang0122
Date: Tue, 18 Jun 2024 22:01:16 +0000
Subject: [PATCH 01/37] init

Signed-off-by: dongyang0122
---
 monai/apps/generation/__init__.py             |  10 +
 monai/apps/generation/maisi/__init__.py       |  10 +
 .../generation/maisi/networks/__init__.py     |  10 +
 .../maisi/networks/autoencoderkl_streaming.py | 935 ++++++++++++++++++
 4 files changed, 965 insertions(+)
 create mode 100644 monai/apps/generation/__init__.py
 create mode 100644 monai/apps/generation/maisi/__init__.py
 create mode 100644 monai/apps/generation/maisi/networks/__init__.py
 create mode 100644 monai/apps/generation/maisi/networks/autoencoderkl_streaming.py

diff --git a/monai/apps/generation/__init__.py b/monai/apps/generation/__init__.py
new file mode 100644
index 0000000000..1e97f89407
--- /dev/null
+++ b/monai/apps/generation/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

diff --git a/monai/apps/generation/maisi/__init__.py b/monai/apps/generation/maisi/__init__.py
new file mode 100644
index 0000000000..1e97f89407
--- /dev/null
+++ b/monai/apps/generation/maisi/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

diff --git a/monai/apps/generation/maisi/networks/__init__.py b/monai/apps/generation/maisi/networks/__init__.py
new file mode 100644
index 0000000000..1e97f89407
--- /dev/null
+++ b/monai/apps/generation/maisi/networks/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
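
The file added next implements the core memory-saving idea of this series: a convolution whose input is split into overlapping chunks along one spatial axis, convolved chunk by chunk, cropped, and re-concatenated, so only one chunk needs to be resident on the GPU at a time. A minimal standalone sketch of why this is equivalent to the direct convolution; it is illustrative only (not part of the patch) and assumes kernel_size=3, stride=1, padding=1, so each chunk needs one voxel of halo per side:

import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv3d(2, 2, kernel_size=3, padding=1)
x = torch.randn(1, 2, 8, 32, 8)        # (N, C, D, H, W)

halo = 1                               # kernel_size // 2 voxels of context per side
num_splits = 4
s = x.size(3) // num_splits            # split the H axis into equal chunks
outs = []
for i in range(num_splits):
    lo, hi = i * s, (i + 1) * s
    chunk = x[:, :, :, max(0, lo - halo) : min(x.size(3), hi + halo), :]
    y = conv(chunk)                    # the conv zero-pads the chunk edges ...
    left = 0 if lo == 0 else halo      # ... so drop the halo rows whose true
    outs.append(y[:, :, :, left : left + s, :])  # context lives in the neighbour
y_split = torch.cat(outs, dim=3)

assert torch.allclose(conv(x), y_split, atol=1e-6)
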
diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_streaming.py b/monai/apps/generation/maisi/networks/autoencoderkl_streaming.py
new file mode 100644
index 0000000000..204751d79f
--- /dev/null
+++ b/monai/apps/generation/maisi/networks/autoencoderkl_streaming.py
@@ -0,0 +1,935 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import monai
+import numpy as np
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Sequence
+from monai.networks.blocks import Convolution
+from generative.networks.nets.autoencoderkl import (
+    AttentionBlock,
+    ResBlock,
+    AutoencoderKL,
+    Encoder,
+)
+
+
+NUM_SPLITS = 16
+
+class InplaceGroupNorm3D(torch.nn.GroupNorm):
+    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
+        super(InplaceGroupNorm3D, self).__init__(num_groups, num_channels, eps, affine)
+
+    def forward(self, input):
+        print("InplaceGroupNorm3D in", input.size())
+
+        # Ensure the tensor is 5D: (N, C, D, H, W)
+        if len(input.shape) != 5:
+            raise ValueError("Expected a 5D tensor")
+
+        N, C, D, H, W = input.shape
+
+        # Reshape to (N, num_groups, C // num_groups, D, H, W)
+        input = input.view(N, self.num_groups, C // self.num_groups, D, H, W)
+
+        means, stds = [], []
+        inputs = []
+        for _i in range(input.size(1)):
+            array = input[:, _i : _i + 1, ...]
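+            # One group at a time: compute this group's statistics in float32 over
+            # the within-group channel and spatial dims, normalise in place, then
+            # store the result as float16 so only a single group is ever held in
+            # full precision.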
+ array = array.to(dtype=torch.float32) + _mean = array.mean([2, 3, 4, 5], keepdim=True) + _std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_() + + _mean = _mean.to(dtype=torch.float32) + _std = _std.to(dtype=torch.float32) + + inputs.append(array.sub_(_mean).div_(_std).to(dtype=torch.float16)) + + del input + torch.cuda.empty_cache() + + if max(inputs[0].size()) < 500: + input = torch.cat([inputs[_k] for _k in range(len(inputs))], dim=1) + else: + _type = inputs[0].device.type + if _type == "cuda": + input = inputs[0].clone().to("cpu", non_blocking=True) + else: + input = inputs[0].clone() + inputs[0] = 0 + torch.cuda.empty_cache() + + for _k in range(len(inputs) - 1): + input = torch.cat((input, inputs[_k + 1].cpu()), dim=1) + inputs[_k + 1] = 0 + torch.cuda.empty_cache() + gc.collect() + print(f"InplaceGroupNorm3D cat: {_k + 1}/{len(inputs) - 1}.") + + if _type == "cuda": + input = input.to("cuda", non_blocking=True) + + # Reshape back to original size + input = input.view(N, C, D, H, W) + + # Apply affine transformation if enabled + if self.affine: + input.mul_(self.weight.view(1, C, 1, 1, 1)).add_(self.bias.view(1, C, 1, 1, 1)) + + print("InplaceGroupNorm3D out", input.size()) + + return input + + +class StreamingConvolution(nn.Module): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + strides: Sequence[int] | int = 1, + kernel_size: Sequence[int] | int = 3, + adn_ordering: str = "NDA", + act: tuple | str | None = "PRELU", + norm: tuple | str | None = "INSTANCE", + dropout: tuple | str | float | None = None, + dropout_dim: int | None = 1, + dilation: Sequence[int] | int = 1, + groups: int = 1, + bias: bool = True, + conv_only: bool = False, + is_transposed: bool = False, + padding: Sequence[int] | int | None = None, + output_padding: Sequence[int] | int | None = None, + ) -> None: + super(StreamingConvolution, self).__init__() + self.conv = monai.networks.blocks.convolutions.Convolution( + spatial_dims, + in_channels, + out_channels, + strides, + kernel_size, + adn_ordering, + act, + norm, + dropout, + dropout_dim, + dilation, + groups, + bias, + conv_only, + is_transposed, + padding, + output_padding, + ) + + self.tp_dim = 1 + self.stride = strides[self.tp_dim] if isinstance(strides, list) else strides + + def forward(self, x): + num_splits = NUM_SPLITS + print("num_splits:", num_splits) + l = x.size(self.tp_dim + 2) + split_size = l // num_splits + + padding = 3 + if padding % self.stride > 0: + padding = (padding // self.stride + 1) * self.stride + print("padding:", padding) + + overlaps = [0] + [padding] * (num_splits - 1) + last_padding = x.size(self.tp_dim + 2) % split_size + + if self.tp_dim == 0: + splits = [ + x[ + :, + :, + i * split_size + - overlaps[i] : (i + 1) * split_size + + (padding if i != num_splits - 1 else last_padding), + :, + :, + ] + for i in range(num_splits) + ] + elif self.tp_dim == 1: + splits = [ + x[ + :, + :, + :, + i * split_size + - overlaps[i] : (i + 1) * split_size + + (padding if i != num_splits - 1 else last_padding), + :, + ] + for i in range(num_splits) + ] + elif self.tp_dim == 2: + splits = [ + x[ + :, + :, + :, + :, + i * split_size + - overlaps[i] : (i + 1) * split_size + + (padding if i != num_splits - 1 else last_padding), + ] + for i in range(num_splits) + ] + + for _j in range(len(splits)): + print(f"splits {_j + 1}/{len(splits)}:", splits[_j].size()) + + del x + torch.cuda.empty_cache() + + splits_0_size = list(splits[0].size()) + print("splits_0_size:", splits_0_size) 
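+        # Keep a copy of the split shape here: the crop arithmetic below still
+        # needs it after each split has been convolved and immediately freed to
+        # cap peak GPU memory.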
+ + outputs = [] + _type = splits[0].device.type + for _i in range(num_splits): + outputs.append(self.conv(splits[_i])) + + splits[_i] = 0 + torch.cuda.empty_cache() + + for _j in range(len(outputs)): + print(f"outputs before {_j + 1}/{len(outputs)}:", outputs[_j].size()) + + del splits + torch.cuda.empty_cache() + + split_size_out = split_size + padding_s = padding + non_tp_dim = self.tp_dim + 1 if self.tp_dim < 2 else 0 + if outputs[0].size(non_tp_dim + 2) // splits_0_size[non_tp_dim + 2] == 2: + split_size_out *= 2 + padding_s *= 2 + elif splits_0_size[non_tp_dim + 2] // outputs[0].size(non_tp_dim + 2) == 2: + split_size_out = split_size_out // 2 + padding_s = padding_s // 2 + + if self.tp_dim == 0: + outputs[0] = outputs[0][:, :, :split_size_out, :, :] + for i in range(1, num_splits): + outputs[i] = outputs[i][:, :, padding_s : padding_s + split_size_out, :, :] + elif self.tp_dim == 1: + print("outputs", outputs[0].size(3), f"padding_s: 0, {split_size_out}") + outputs[0] = outputs[0][:, :, :, :split_size_out, :] + for i in range(1, num_splits): + print( + "outputs", + outputs[i].size(3), + f"padding_s: {padding_s}, {padding_s + split_size_out}", + ) + outputs[i] = outputs[i][:, :, :, padding_s : padding_s + split_size_out, :] + elif self.tp_dim == 2: + outputs[0] = outputs[0][:, :, :, :, :split_size_out] + for i in range(1, num_splits): + outputs[i] = outputs[i][:, :, :, :, padding_s : padding_s + split_size_out] + + for i in range(num_splits): + print(f"outputs after {i + 1}/{len(outputs)}:", outputs[i].size()) + + if max(outputs[0].size()) < 500: + print(f"outputs[0].device.type: {outputs[0].device.type}.") + x = torch.cat([out for out in outputs], dim=self.tp_dim + 2) + else: + _type = outputs[0].device.type + if _type == "cuda": + x = outputs[0].clone().to("cpu", non_blocking=True) + outputs[0] = 0 + torch.cuda.empty_cache() + for _k in range(len(outputs) - 1): + x = torch.cat((x, outputs[_k + 1].cpu()), dim=self.tp_dim + 2) + outputs[_k + 1] = 0 + torch.cuda.empty_cache() + gc.collect() + print(f"StreamingConvolution cat: {_k + 1}/{len(outputs) - 1}.") + if _type == "cuda": + x = x.to("cuda", non_blocking=True) + + del outputs + torch.cuda.empty_cache() + + return x + + +class StreamingUpsample(nn.Module): + """ + Convolution-based upsampling layer. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels to the layer. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) -> None: + super().__init__() + if use_convtranspose: + self.conv = StreamingConvolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=1, + conv_only=True, + is_transposed=True, + ) + else: + self.conv = StreamingConvolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.use_convtranspose = use_convtranspose + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_convtranspose: + return self.conv(x) + + dtype = x.dtype + + x = F.interpolate(x, scale_factor=2.0, mode="trilinear") + torch.cuda.empty_cache() + + x = self.conv(x) + torch.cuda.empty_cache() + + return x + + +class StreamingDownsample(nn.Module): + """ + Convolution-based downsampling layer. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). 
+ in_channels: number of input channels. + """ + + def __init__(self, spatial_dims: int, in_channels: int) -> None: + super().__init__() + self.pad = (0, 1) * spatial_dims + + self.conv = StreamingConvolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = nn.functional.pad(x, self.pad, mode="constant", value=0.0) + x = self.conv(x) + return x + + +class StreamingResBlock(nn.Module): + """ + Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a + residual connection between input and output. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: input channels to the layer. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon for the normalisation. + out_channels: number of output channels. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + norm_num_groups: int, + norm_eps: float, + out_channels: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = InplaceGroupNorm3D( + num_groups=norm_num_groups, + num_channels=in_channels, + eps=norm_eps, + affine=True, + ) + self.conv1 = StreamingConvolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.norm2 = InplaceGroupNorm3D( + num_groups=norm_num_groups, + num_channels=out_channels, + eps=norm_eps, + affine=True, + ) + self.conv2 = StreamingConvolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + if self.in_channels != self.out_channels: + self.nin_shortcut = StreamingConvolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + else: + self.nin_shortcut = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + torch.cuda.empty_cache() + + h = F.silu(h) + torch.cuda.empty_cache() + h = self.conv1(h) + torch.cuda.empty_cache() + + h = self.norm2(h) + torch.cuda.empty_cache() + + h = F.silu(h) + torch.cuda.empty_cache() + h = self.conv2(h) + torch.cuda.empty_cache() + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + torch.cuda.empty_cache() + + return x + h + + +class StreamingEncoder(nn.Module): + """ + Convolutional cascade that downsamples the image into a spatial latent space. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + num_channels: sequence of block output channels. + out_channels: number of channels in the bottom layer (latent space) of the autoencoder. + num_res_blocks: number of residual blocks (see ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from num_channels contain an attention block. + with_nonlocal_attn: if True use non-local attention block. 
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_channels: Sequence[int], + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.num_channels = num_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + + blocks = [] + + # Initial convolution + blocks.append( + StreamingConvolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Residual and downsampling blocks + output_channel = num_channels[0] + for i in range(len(num_channels)): + input_channel = output_channel + output_channel = num_channels[i] + is_final_block = i == len(num_channels) - 1 + + for _ in range(self.num_res_blocks[i]): + blocks.append( + StreamingResBlock( + spatial_dims=spatial_dims, + in_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=output_channel, + ) + ) + input_channel = output_channel + if attention_levels[i]: + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append(StreamingDownsample(spatial_dims=spatial_dims, in_channels=input_channel)) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=num_channels[-1], + ) + ) + + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=num_channels[-1], + ) + ) + # Normalise and convert to latent size + blocks.append( + InplaceGroupNorm3D( + num_groups=norm_num_groups, + num_channels=num_channels[-1], + eps=norm_eps, + affine=True, + ) + ) + blocks.append( + StreamingConvolution( + spatial_dims=self.spatial_dims, + in_channels=num_channels[-1], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + torch.cuda.empty_cache() + return x + + +class StreamingDecoder(nn.Module): + """ + Convolutional cascade upsampling from a spatial latent space into an image space. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + num_channels: sequence of block output channels. + in_channels: number of channels in the bottom layer (latent space) of the autoencoder. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see ResBlock) per level. 
+ norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from num_channels contain an attention block. + with_nonlocal_attn: if True use non-local attention block. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__( + self, + spatial_dims: int, + num_channels: Sequence[int], + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_convtranspose: bool = False, + tp_dim: int = 1, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.num_channels = num_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + self.tp_dim = tp_dim + + reversed_block_out_channels = list(reversed(num_channels)) + + blocks = [] + + # Initial convolution + blocks.append( + StreamingConvolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=reversed_block_out_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + block_out_ch = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + block_in_ch = block_out_ch + block_out_ch = reversed_block_out_channels[i] + is_final_block = i == len(num_channels) - 1 + + for _ in range(reversed_num_res_blocks[i]): + blocks.append( + StreamingResBlock( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=block_out_ch, + ) + ) + block_in_ch = block_out_ch + + if reversed_attention_levels[i]: + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append( + StreamingUpsample( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + use_convtranspose=use_convtranspose, + ) + ) + + blocks.append( + InplaceGroupNorm3D( + num_groups=norm_num_groups, + num_channels=block_in_ch, + eps=norm_eps, + affine=True, + ) + ) + blocks.append( + StreamingConvolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + 
conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for _i in range(len(self.blocks)): + block = self.blocks[_i] + print(block, type(block), type(type(block))) + + if _i < len(self.blocks) - 0: + x = block(x) + torch.cuda.empty_cache() + else: + num_splits = NUM_SPLITS + print("num_splits:", num_splits) + + l = x.size(self.tp_dim + 2) + split_size = l // num_splits + + padding = 3 + print("padding:", padding) + + overlaps = [0] + [padding] * (num_splits - 1) + if self.tp_dim == 0: + splits = [ + x[ + :, + :, + i * split_size + - overlaps[i] : (i + 1) * split_size + + (padding if i != num_splits - 1 else 0), + :, + :, + ] + for i in range(num_splits) + ] + elif self.tp_dim == 1: + splits = [ + x[ + :, + :, + :, + i * split_size + - overlaps[i] : (i + 1) * split_size + + (padding if i != num_splits - 1 else 0), + :, + ] + for i in range(num_splits) + ] + elif self.tp_dim == 2: + splits = [ + x[ + :, + :, + :, + :, + i * split_size + - overlaps[i] : (i + 1) * split_size + + (padding if i != num_splits - 1 else 0), + ] + for i in range(num_splits) + ] + + for _j in range(len(splits)): + print(f"splits {_j + 1}/{len(splits)}:", splits[_j].size()) + + del x + torch.cuda.empty_cache() + + outputs = [block(splits[i]) for i in range(num_splits)] + + del splits + torch.cuda.empty_cache() + + split_size_out = split_size + padding_s = padding + non_tp_dim = self.tp_dim + 1 if self.tp_dim < 2 else 0 + if outputs[0].size(non_tp_dim + 2) // splits[0].size(non_tp_dim + 2) == 2: + split_size_out *= 2 + padding_s *= 2 + print("split_size_out:", split_size_out) + print("padding_s:", padding_s) + + if self.tp_dim == 0: + outputs[0] = outputs[0][:, :, :split_size_out, :, :] + for i in range(1, num_splits): + outputs[i] = outputs[i][:, :, padding_s : padding_s + split_size_out, :, :] + elif self.tp_dim == 1: + print("outputs", outputs[0].size(3), f"padding_s: 0, {split_size_out}") + outputs[0] = outputs[0][:, :, :, :split_size_out, :] + for i in range(1, num_splits): + print( + "outputs", + outputs[i].size(3), + f"padding_s: {padding_s}, {padding_s + split_size_out}", + ) + outputs[i] = outputs[i][:, :, :, padding_s : padding_s + split_size_out, :] + elif self.tp_dim == 2: + outputs[0] = outputs[0][:, :, :, :, :split_size_out] + for i in range(1, num_splits): + outputs[i] = outputs[i][:, :, :, :, padding_s : padding_s + split_size_out] + + for i in range(num_splits): + print(f"outputs after {i + 1}/{len(outputs)}:", outputs[i].size()) + + if max(outputs[0].size()) < 500: + x = torch.cat([out for out in outputs], dim=self.tp_dim + 2) + else: + x = outputs[0].clone().to("cpu", non_blocking=True) + outputs[0] = 0 + torch.cuda.empty_cache() + for _k in range(len(outputs) - 1): + x = torch.cat((x, outputs[_k + 1].cpu()), dim=self.tp_dim + 2) + outputs[_k + 1] = 0 + torch.cuda.empty_cache() + gc.collect() + print(f"cat: {_k + 1}/{len(outputs) - 1}.") + x = x.to("cuda", non_blocking=True) + + del outputs + torch.cuda.empty_cache() + + return x + + +class StreamingAutoencoderKL(AutoencoderKL): + """ + Override encoder to make it align with original ldm codebase and support activation checkpointing. 
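+    The constructor arguments are identical to ``AutoencoderKL``; only the
+    encoder and decoder are swapped for the streaming implementations above.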
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + num_channels: Sequence[int], + attention_levels: Sequence[bool], + latent_channels: int = 3, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + with_encoder_nonlocal_attn: bool = True, + with_decoder_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_checkpointing: bool = False, + use_convtranspose: bool = False, + ) -> None: + super().__init__( + spatial_dims, + in_channels, + out_channels, + num_res_blocks, + num_channels, + attention_levels, + latent_channels, + norm_num_groups, + norm_eps, + with_encoder_nonlocal_attn, + with_decoder_nonlocal_attn, + use_flash_attention, + use_checkpointing, + use_convtranspose, + ) + + self.encoder = StreamingEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_channels=num_channels, + out_channels=latent_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_encoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + ) + + # Override decoder using transposed conv + self.decoder = StreamingDecoder( + spatial_dims=spatial_dims, + num_channels=num_channels, + in_channels=latent_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_decoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + use_convtranspose=use_convtranspose, + ) From 356230180d35aa1593d6364eb4f1d62d35c6416f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Jun 2024 22:06:17 +0000 Subject: [PATCH 02/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../generation/maisi/networks/autoencoderkl_streaming.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_streaming.py b/monai/apps/generation/maisi/networks/autoencoderkl_streaming.py index 204751d79f..9415d971d3 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_streaming.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_streaming.py @@ -11,19 +11,15 @@ import gc import monai -import numpy as np -import os import torch import torch.nn as nn import torch.nn.functional as F from typing import Sequence -from monai.networks.blocks import Convolution from generative.networks.nets.autoencoderkl import ( AttentionBlock, ResBlock, AutoencoderKL, - Encoder, ) @@ -31,7 +27,7 @@ class InplaceGroupNorm3D(torch.nn.GroupNorm): def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): - super(InplaceGroupNorm3D, self).__init__(num_groups, num_channels, eps, affine) + super().__init__(num_groups, num_channels, eps, affine) def forward(self, input): print("InplaceGroupNorm3D in", input.size()) @@ -115,7 +111,7 @@ def __init__( padding: Sequence[int] | int | None = None, output_padding: Sequence[int] | int | None = None, ) -> None: - super(StreamingConvolution, self).__init__() + super().__init__() self.conv = monai.networks.blocks.convolutions.Convolution( spatial_dims, in_channels, From 42ee14a2716e708fb9ca43318485c354b927478c Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Wed, 19 Jun 2024 15:10:53 +0000 Subject: [PATCH 03/37] update Signed-off-by: dongyang0122 --- ...kl_streaming.py => 
autoencoderkl_maisi.py} | 39 +++++++++++++++---- 1 file changed, 32 insertions(+), 7 deletions(-) rename monai/apps/generation/maisi/networks/{autoencoderkl_streaming.py => autoencoderkl_maisi.py} (96%) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_streaming.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py similarity index 96% rename from monai/apps/generation/maisi/networks/autoencoderkl_streaming.py rename to monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 204751d79f..39faa1297d 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_streaming.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -27,7 +27,7 @@ ) -NUM_SPLITS = 16 +# NUM_SPLITS = 16 class InplaceGroupNorm3D(torch.nn.GroupNorm): def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): @@ -100,6 +100,7 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, + num_splits: int, strides: Sequence[int] | int = 1, kernel_size: Sequence[int] | int = 3, adn_ordering: str = "NDA", @@ -138,9 +139,11 @@ def __init__( self.tp_dim = 1 self.stride = strides[self.tp_dim] if isinstance(strides, list) else strides + self.num_splits = num_splits def forward(self, x): - num_splits = NUM_SPLITS + # num_splits = NUM_SPLITS + num_splits = self.num_splits print("num_splits:", num_splits) l = x.size(self.tp_dim + 2) split_size = l // num_splits @@ -282,7 +285,7 @@ class StreamingUpsample(nn.Module): use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. """ - def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) -> None: + def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool, num_splits: int) -> None: super().__init__() if use_convtranspose: self.conv = StreamingConvolution( @@ -294,6 +297,7 @@ def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) padding=1, conv_only=True, is_transposed=True, + num_splits=num_splits, ) else: self.conv = StreamingConvolution( @@ -304,6 +308,7 @@ def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) kernel_size=3, padding=1, conv_only=True, + num_splits=num_splits, ) self.use_convtranspose = use_convtranspose @@ -331,7 +336,7 @@ class StreamingDownsample(nn.Module): in_channels: number of input channels. 
""" - def __init__(self, spatial_dims: int, in_channels: int) -> None: + def __init__(self, spatial_dims: int, in_channels: int, num_splits: int) -> None: super().__init__() self.pad = (0, 1) * spatial_dims @@ -343,6 +348,7 @@ def __init__(self, spatial_dims: int, in_channels: int) -> None: kernel_size=3, padding=0, conv_only=True, + num_splits=num_splits, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -372,6 +378,7 @@ def __init__( norm_num_groups: int, norm_eps: float, out_channels: int, + num_splits: int, ) -> None: super().__init__() self.in_channels = in_channels @@ -391,6 +398,7 @@ def __init__( kernel_size=3, padding=1, conv_only=True, + num_splits=num_splits, ) self.norm2 = InplaceGroupNorm3D( num_groups=norm_num_groups, @@ -406,6 +414,7 @@ def __init__( kernel_size=3, padding=1, conv_only=True, + num_splits=num_splits, ) if self.in_channels != self.out_channels: @@ -417,6 +426,7 @@ def __init__( kernel_size=1, padding=0, conv_only=True, + num_splits=num_splits, ) else: self.nin_shortcut = nn.Identity() @@ -473,6 +483,7 @@ def __init__( norm_num_groups: int, norm_eps: float, attention_levels: Sequence[bool], + num_splits: int, with_nonlocal_attn: bool = True, use_flash_attention: bool = False, ) -> None: @@ -485,6 +496,7 @@ def __init__( self.norm_num_groups = norm_num_groups self.norm_eps = norm_eps self.attention_levels = attention_levels + self.num_splits = num_splits blocks = [] @@ -498,6 +510,7 @@ def __init__( kernel_size=3, padding=1, conv_only=True, + num_splits=num_splits, ) ) @@ -516,6 +529,7 @@ def __init__( norm_num_groups=norm_num_groups, norm_eps=norm_eps, out_channels=output_channel, + num_splits=num_splits, ) ) input_channel = output_channel @@ -531,7 +545,7 @@ def __init__( ) if not is_final_block: - blocks.append(StreamingDownsample(spatial_dims=spatial_dims, in_channels=input_channel)) + blocks.append(StreamingDownsample(spatial_dims=spatial_dims, in_channels=input_channel, num_splits=num_splits)) # Non-local attention block if with_nonlocal_attn is True: @@ -581,6 +595,7 @@ def __init__( kernel_size=3, padding=1, conv_only=True, + num_splits=num_splits, ) ) @@ -621,6 +636,7 @@ def __init__( norm_num_groups: int, norm_eps: float, attention_levels: Sequence[bool], + num_splits: int, with_nonlocal_attn: bool = True, use_flash_attention: bool = False, use_convtranspose: bool = False, @@ -635,6 +651,7 @@ def __init__( self.norm_num_groups = norm_num_groups self.norm_eps = norm_eps self.attention_levels = attention_levels + self.num_splits = num_splits self.tp_dim = tp_dim reversed_block_out_channels = list(reversed(num_channels)) @@ -651,6 +668,7 @@ def __init__( kernel_size=3, padding=1, conv_only=True, + num_splits=num_splits, ) ) @@ -700,6 +718,7 @@ def __init__( norm_num_groups=norm_num_groups, norm_eps=norm_eps, out_channels=block_out_ch, + num_splits=num_splits, ) ) block_in_ch = block_out_ch @@ -721,6 +740,7 @@ def __init__( spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=use_convtranspose, + num_splits=num_splits, ) ) @@ -741,6 +761,7 @@ def __init__( kernel_size=3, padding=1, conv_only=True, + num_splits=num_splits, ) ) @@ -755,7 +776,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = block(x) torch.cuda.empty_cache() else: - num_splits = NUM_SPLITS + # num_splits = NUM_SPLITS + num_splits = self.num_splits print("num_splits:", num_splits) l = x.size(self.tp_dim + 2) @@ -867,7 +889,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class StreamingAutoencoderKL(AutoencoderKL): +class 
AutoencoderKlMaisi(AutoencoderKL): """ Override encoder to make it align with original ldm codebase and support activation checkpointing. """ @@ -888,6 +910,7 @@ def __init__( use_flash_attention: bool = False, use_checkpointing: bool = False, use_convtranspose: bool = False, + num_splits: int = 16, ) -> None: super().__init__( spatial_dims, @@ -917,6 +940,7 @@ def __init__( attention_levels=attention_levels, with_nonlocal_attn=with_encoder_nonlocal_attn, use_flash_attention=use_flash_attention, + num_splits=num_splits, ) # Override decoder using transposed conv @@ -932,4 +956,5 @@ def __init__( with_nonlocal_attn=with_decoder_nonlocal_attn, use_flash_attention=use_flash_attention, use_convtranspose=use_convtranspose, + num_splits=num_splits, ) From 40609498dba12308ca1a62b9a50873d2751c37d2 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Wed, 19 Jun 2024 15:51:29 +0000 Subject: [PATCH 04/37] update Signed-off-by: dongyang0122 --- .../maisi/networks/autoencoderkl_maisi.py | 197 +++++++++++------- 1 file changed, 127 insertions(+), 70 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index a52d4d7346..c6ee8a3f93 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -23,14 +23,14 @@ ) -# NUM_SPLITS = 16 - -class InplaceGroupNorm3D(torch.nn.GroupNorm): - def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): +class MaisiGroupNorm3D(torch.nn.GroupNorm): + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, debug=True): super().__init__(num_groups, num_channels, eps, affine) + self.debug = debug def forward(self, input): - print("InplaceGroupNorm3D in", input.size()) + if self.debug: + print("MaisiGroupNorm3D in", input.size()) # Ensure the tensor is 5D: (N, C, D, H, W) if len(input.shape) != 5: @@ -73,7 +73,9 @@ def forward(self, input): inputs[_k + 1] = 0 torch.cuda.empty_cache() gc.collect() - print(f"InplaceGroupNorm3D cat: {_k + 1}/{len(inputs) - 1}.") + + if self.debug: + print(f"MaisiGroupNorm3D cat: {_k + 1}/{len(inputs) - 1}.") if _type == "cuda": input = input.to("cuda", non_blocking=True) @@ -85,18 +87,20 @@ def forward(self, input): if self.affine: input.mul_(self.weight.view(1, C, 1, 1, 1)).add_(self.bias.view(1, C, 1, 1, 1)) - print("InplaceGroupNorm3D out", input.size()) + if self.debug: + print("MaisiGroupNorm3D out", input.size()) return input -class StreamingConvolution(nn.Module): +class MaisiConvolution(nn.Module): def __init__( self, spatial_dims: int, in_channels: int, out_channels: int, num_splits: int, + debug: bool, strides: Sequence[int] | int = 1, kernel_size: Sequence[int] | int = 3, adn_ordering: str = "NDA", @@ -136,18 +140,21 @@ def __init__( self.tp_dim = 1 self.stride = strides[self.tp_dim] if isinstance(strides, list) else strides self.num_splits = num_splits + self.debug = debug def forward(self, x): - # num_splits = NUM_SPLITS num_splits = self.num_splits - print("num_splits:", num_splits) + if self.debug: + print("num_splits:", num_splits) + l = x.size(self.tp_dim + 2) split_size = l // num_splits padding = 3 if padding % self.stride > 0: padding = (padding // self.stride + 1) * self.stride - print("padding:", padding) + if self.debug: + print("padding:", padding) overlaps = [0] + [padding] * (num_splits - 1) last_padding = x.size(self.tp_dim + 2) % split_size @@ -192,14 +199,16 @@ def forward(self, x): for i in range(num_splits) ] - for _j in 
range(len(splits)): - print(f"splits {_j + 1}/{len(splits)}:", splits[_j].size()) + if self.debug: + for _j in range(len(splits)): + print(f"splits {_j + 1}/{len(splits)}:", splits[_j].size()) del x torch.cuda.empty_cache() splits_0_size = list(splits[0].size()) - print("splits_0_size:", splits_0_size) + if self.debug: + print("splits_0_size:", splits_0_size) outputs = [] _type = splits[0].device.type @@ -209,8 +218,9 @@ def forward(self, x): splits[_i] = 0 torch.cuda.empty_cache() - for _j in range(len(outputs)): - print(f"outputs before {_j + 1}/{len(outputs)}:", outputs[_j].size()) + if self.debug: + for _j in range(len(outputs)): + print(f"outputs before {_j + 1}/{len(outputs)}:", outputs[_j].size()) del splits torch.cuda.empty_cache() @@ -230,25 +240,29 @@ def forward(self, x): for i in range(1, num_splits): outputs[i] = outputs[i][:, :, padding_s : padding_s + split_size_out, :, :] elif self.tp_dim == 1: - print("outputs", outputs[0].size(3), f"padding_s: 0, {split_size_out}") + if self.debug: + print("outputs", outputs[0].size(3), f"padding_s: 0, {split_size_out}") outputs[0] = outputs[0][:, :, :, :split_size_out, :] for i in range(1, num_splits): - print( - "outputs", - outputs[i].size(3), - f"padding_s: {padding_s}, {padding_s + split_size_out}", - ) + if self.debug: + print( + "outputs", + outputs[i].size(3), + f"padding_s: {padding_s}, {padding_s + split_size_out}", + ) outputs[i] = outputs[i][:, :, :, padding_s : padding_s + split_size_out, :] elif self.tp_dim == 2: outputs[0] = outputs[0][:, :, :, :, :split_size_out] for i in range(1, num_splits): outputs[i] = outputs[i][:, :, :, :, padding_s : padding_s + split_size_out] - for i in range(num_splits): - print(f"outputs after {i + 1}/{len(outputs)}:", outputs[i].size()) + if self.debug: + for i in range(num_splits): + print(f"outputs after {i + 1}/{len(outputs)}:", outputs[i].size()) if max(outputs[0].size()) < 500: - print(f"outputs[0].device.type: {outputs[0].device.type}.") + if self.debug: + print(f"outputs[0].device.type: {outputs[0].device.type}.") x = torch.cat([out for out in outputs], dim=self.tp_dim + 2) else: _type = outputs[0].device.type @@ -261,7 +275,8 @@ def forward(self, x): outputs[_k + 1] = 0 torch.cuda.empty_cache() gc.collect() - print(f"StreamingConvolution cat: {_k + 1}/{len(outputs) - 1}.") + if self.debug: + print(f"MaisiConvolution cat: {_k + 1}/{len(outputs) - 1}.") if _type == "cuda": x = x.to("cuda", non_blocking=True) @@ -271,7 +286,7 @@ def forward(self, x): return x -class StreamingUpsample(nn.Module): +class MaisiUpsample(nn.Module): """ Convolution-based upsampling layer. @@ -281,10 +296,12 @@ class StreamingUpsample(nn.Module): use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. 
""" - def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool, num_splits: int) -> None: + def __init__( + self, spatial_dims: int, in_channels: int, use_convtranspose: bool, num_splits: int, debug: bool + ) -> None: super().__init__() if use_convtranspose: - self.conv = StreamingConvolution( + self.conv = MaisiConvolution( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, @@ -294,9 +311,10 @@ def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool, conv_only=True, is_transposed=True, num_splits=num_splits, + debug=debug, ) else: - self.conv = StreamingConvolution( + self.conv = MaisiConvolution( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, @@ -305,6 +323,7 @@ def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool, padding=1, conv_only=True, num_splits=num_splits, + debug=debug, ) self.use_convtranspose = use_convtranspose @@ -323,7 +342,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class StreamingDownsample(nn.Module): +class MaisiDownsample(nn.Module): """ Convolution-based downsampling layer. @@ -332,11 +351,11 @@ class StreamingDownsample(nn.Module): in_channels: number of input channels. """ - def __init__(self, spatial_dims: int, in_channels: int, num_splits: int) -> None: + def __init__(self, spatial_dims: int, in_channels: int, num_splits: int, debug: bool) -> None: super().__init__() self.pad = (0, 1) * spatial_dims - self.conv = StreamingConvolution( + self.conv = MaisiConvolution( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, @@ -345,6 +364,7 @@ def __init__(self, spatial_dims: int, in_channels: int, num_splits: int) -> None padding=0, conv_only=True, num_splits=num_splits, + debug=debug, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -353,7 +373,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class StreamingResBlock(nn.Module): +class MaisiResBlock(nn.Module): """ Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a residual connection between input and output. 
@@ -375,18 +395,20 @@ def __init__( norm_eps: float, out_channels: int, num_splits: int, + debug: bool, ) -> None: super().__init__() self.in_channels = in_channels self.out_channels = in_channels if out_channels is None else out_channels - self.norm1 = InplaceGroupNorm3D( + self.norm1 = MaisiGroupNorm3D( num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True, + debug=debug, ) - self.conv1 = StreamingConvolution( + self.conv1 = MaisiConvolution( spatial_dims=spatial_dims, in_channels=self.in_channels, out_channels=self.out_channels, @@ -395,14 +417,16 @@ def __init__( padding=1, conv_only=True, num_splits=num_splits, + debug=debug, ) - self.norm2 = InplaceGroupNorm3D( + self.norm2 = MaisiGroupNorm3D( num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True, + debug=debug, ) - self.conv2 = StreamingConvolution( + self.conv2 = MaisiConvolution( spatial_dims=spatial_dims, in_channels=self.out_channels, out_channels=self.out_channels, @@ -411,10 +435,11 @@ def __init__( padding=1, conv_only=True, num_splits=num_splits, + debug=debug, ) if self.in_channels != self.out_channels: - self.nin_shortcut = StreamingConvolution( + self.nin_shortcut = MaisiConvolution( spatial_dims=spatial_dims, in_channels=self.in_channels, out_channels=self.out_channels, @@ -423,6 +448,7 @@ def __init__( padding=0, conv_only=True, num_splits=num_splits, + debug=debug, ) else: self.nin_shortcut = nn.Identity() @@ -452,7 +478,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x + h -class StreamingEncoder(nn.Module): +class MaisiEncoder(nn.Module): """ Convolutional cascade that downsamples the image into a spatial latent space. @@ -480,6 +506,7 @@ def __init__( norm_eps: float, attention_levels: Sequence[bool], num_splits: int, + debug: bool, with_nonlocal_attn: bool = True, use_flash_attention: bool = False, ) -> None: @@ -498,7 +525,7 @@ def __init__( # Initial convolution blocks.append( - StreamingConvolution( + MaisiConvolution( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=num_channels[0], @@ -507,6 +534,7 @@ def __init__( padding=1, conv_only=True, num_splits=num_splits, + debug=debug, ) ) @@ -519,13 +547,14 @@ def __init__( for _ in range(self.num_res_blocks[i]): blocks.append( - StreamingResBlock( + MaisiResBlock( spatial_dims=spatial_dims, in_channels=input_channel, norm_num_groups=norm_num_groups, norm_eps=norm_eps, out_channels=output_channel, num_splits=num_splits, + debug=debug, ) ) input_channel = output_channel @@ -541,7 +570,11 @@ def __init__( ) if not is_final_block: - blocks.append(StreamingDownsample(spatial_dims=spatial_dims, in_channels=input_channel, num_splits=num_splits)) + blocks.append( + MaisiDownsample( + spatial_dims=spatial_dims, in_channels=input_channel, num_splits=num_splits, debug=debug + ) + ) # Non-local attention block if with_nonlocal_attn is True: @@ -575,15 +608,16 @@ def __init__( ) # Normalise and convert to latent size blocks.append( - InplaceGroupNorm3D( + MaisiGroupNorm3D( num_groups=norm_num_groups, num_channels=num_channels[-1], eps=norm_eps, affine=True, + debug=debug, ) ) blocks.append( - StreamingConvolution( + MaisiConvolution( spatial_dims=self.spatial_dims, in_channels=num_channels[-1], out_channels=out_channels, @@ -592,6 +626,7 @@ def __init__( padding=1, conv_only=True, num_splits=num_splits, + debug=debug, ) ) @@ -604,7 +639,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class StreamingDecoder(nn.Module): +class MaisiDecoder(nn.Module): """ Convolutional cascade 
upsampling from a spatial latent space into an image space. @@ -633,6 +668,7 @@ def __init__( norm_eps: float, attention_levels: Sequence[bool], num_splits: int, + debug: bool, with_nonlocal_attn: bool = True, use_flash_attention: bool = False, use_convtranspose: bool = False, @@ -648,6 +684,7 @@ def __init__( self.norm_eps = norm_eps self.attention_levels = attention_levels self.num_splits = num_splits + self.debug = debug self.tp_dim = tp_dim reversed_block_out_channels = list(reversed(num_channels)) @@ -656,7 +693,7 @@ def __init__( # Initial convolution blocks.append( - StreamingConvolution( + MaisiConvolution( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=reversed_block_out_channels[0], @@ -665,6 +702,7 @@ def __init__( padding=1, conv_only=True, num_splits=num_splits, + debug=debug, ) ) @@ -708,13 +746,14 @@ def __init__( for _ in range(reversed_num_res_blocks[i]): blocks.append( - StreamingResBlock( + MaisiResBlock( spatial_dims=spatial_dims, in_channels=block_in_ch, norm_num_groups=norm_num_groups, norm_eps=norm_eps, out_channels=block_out_ch, num_splits=num_splits, + debug=debug, ) ) block_in_ch = block_out_ch @@ -732,24 +771,26 @@ def __init__( if not is_final_block: blocks.append( - StreamingUpsample( + MaisiUpsample( spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=use_convtranspose, num_splits=num_splits, + debug=debug, ) ) blocks.append( - InplaceGroupNorm3D( + MaisiGroupNorm3D( num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True, + debug=debug, ) ) blocks.append( - StreamingConvolution( + MaisiConvolution( spatial_dims=spatial_dims, in_channels=block_in_ch, out_channels=out_channels, @@ -758,6 +799,7 @@ def __init__( padding=1, conv_only=True, num_splits=num_splits, + debug=debug, ) ) @@ -766,21 +808,25 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: for _i in range(len(self.blocks)): block = self.blocks[_i] - print(block, type(block), type(type(block))) + + if self.debug: + print(block, type(block), type(type(block))) if _i < len(self.blocks) - 0: x = block(x) torch.cuda.empty_cache() else: - # num_splits = NUM_SPLITS num_splits = self.num_splits - print("num_splits:", num_splits) + + if self.debug: + print("num_splits:", num_splits) l = x.size(self.tp_dim + 2) split_size = l // num_splits padding = 3 - print("padding:", padding) + if self.debug: + print("padding:", padding) overlaps = [0] + [padding] * (num_splits - 1) if self.tp_dim == 0: @@ -823,8 +869,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: for i in range(num_splits) ] - for _j in range(len(splits)): - print(f"splits {_j + 1}/{len(splits)}:", splits[_j].size()) + if debug: + for _j in range(len(splits)): + print(f"splits {_j + 1}/{len(splits)}:", splits[_j].size()) del x torch.cuda.empty_cache() @@ -840,30 +887,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if outputs[0].size(non_tp_dim + 2) // splits[0].size(non_tp_dim + 2) == 2: split_size_out *= 2 padding_s *= 2 - print("split_size_out:", split_size_out) - print("padding_s:", padding_s) + + if self.debug: + print("split_size_out:", split_size_out) + print("padding_s:", padding_s) if self.tp_dim == 0: outputs[0] = outputs[0][:, :, :split_size_out, :, :] for i in range(1, num_splits): outputs[i] = outputs[i][:, :, padding_s : padding_s + split_size_out, :, :] elif self.tp_dim == 1: - print("outputs", outputs[0].size(3), f"padding_s: 0, {split_size_out}") + if self.debug: + print("outputs", outputs[0].size(3), f"padding_s: 0, {split_size_out}") + 
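+                # Crop rule: the first split carried no leading overlap, so it
+                # keeps [0, split_size_out); every later split drops its first
+                # padding_s rows, which only duplicated the tail of its
+                # left-hand neighbour.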
outputs[0] = outputs[0][:, :, :, :split_size_out, :] for i in range(1, num_splits): - print( - "outputs", - outputs[i].size(3), - f"padding_s: {padding_s}, {padding_s + split_size_out}", - ) + if self.debug: + print( + "outputs", + outputs[i].size(3), + f"padding_s: {padding_s}, {padding_s + split_size_out}", + ) outputs[i] = outputs[i][:, :, :, padding_s : padding_s + split_size_out, :] elif self.tp_dim == 2: outputs[0] = outputs[0][:, :, :, :, :split_size_out] for i in range(1, num_splits): outputs[i] = outputs[i][:, :, :, :, padding_s : padding_s + split_size_out] - for i in range(num_splits): - print(f"outputs after {i + 1}/{len(outputs)}:", outputs[i].size()) + if self.debug: + for i in range(num_splits): + print(f"outputs after {i + 1}/{len(outputs)}:", outputs[i].size()) if max(outputs[0].size()) < 500: x = torch.cat([out for out in outputs], dim=self.tp_dim + 2) @@ -876,7 +929,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: outputs[_k + 1] = 0 torch.cuda.empty_cache() gc.collect() - print(f"cat: {_k + 1}/{len(outputs) - 1}.") + if self.debug: + print(f"cat: {_k + 1}/{len(outputs) - 1}.") x = x.to("cuda", non_blocking=True) del outputs @@ -907,6 +961,7 @@ def __init__( use_checkpointing: bool = False, use_convtranspose: bool = False, num_splits: int = 16, + debug: bool = True, ) -> None: super().__init__( spatial_dims, @@ -925,7 +980,7 @@ def __init__( use_convtranspose, ) - self.encoder = StreamingEncoder( + self.encoder = MaisiEncoder( spatial_dims=spatial_dims, in_channels=in_channels, num_channels=num_channels, @@ -937,10 +992,11 @@ def __init__( with_nonlocal_attn=with_encoder_nonlocal_attn, use_flash_attention=use_flash_attention, num_splits=num_splits, + debug=debug, ) # Override decoder using transposed conv - self.decoder = StreamingDecoder( + self.decoder = MaisiDecoder( spatial_dims=spatial_dims, num_channels=num_channels, in_channels=latent_channels, @@ -953,4 +1009,5 @@ def __init__( use_flash_attention=use_flash_attention, use_convtranspose=use_convtranspose, num_splits=num_splits, + debug=debug, ) From e415abd1a5844a16982275ae740bf7b779341181 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Wed, 19 Jun 2024 16:30:25 +0000 Subject: [PATCH 05/37] update Signed-off-by: dongyang0122 --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 981f126020..9f05a237e0 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -96,6 +96,7 @@ class MaisiConvolution(nn.Module): num_splits: Number of splits for the input tensor. debug: Whether to print debug information. Additional arguments for the convolution operation. 
+ https://docs.monai.io/en/stable/networks.html#convolution """ def __init__( From c7d68d28382704ee79f88e06596a5c57fb844ae5 Mon Sep 17 00:00:00 2001 From: Dong Yang Date: Wed, 19 Jun 2024 10:41:34 -0600 Subject: [PATCH 06/37] update Signed-off-by: Dong Yang --- tests/test_autoencoderkl_maisi.py | 256 ++++++++++++++++++++++++++++++ 1 file changed, 256 insertions(+) create mode 100644 tests/test_autoencoderkl_maisi.py diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py new file mode 100644 index 0000000000..8d7d77d3b4 --- /dev/null +++ b/tests/test_autoencoderkl_maisi.py @@ -0,0 +1,256 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import tempfile +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.apps import download_url +from monai.networks import eval_mode +from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, testing_data_config +from autoencoder_kl_maisi import AutoencoderKlMaisi # Assuming the class is in a file named autoencoder_kl_maisi.py + +tqdm, has_tqdm = optional_import("tqdm", name="tqdm") +_, has_einops = optional_import("einops") + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +CASES_NO_ATTENTION = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "num_splits": 1, + "debug": False, + }, + (1, 1, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], +] + +CASES_ATTENTION = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_splits": 1, + "debug": False, + }, + (1, 1, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], +] + +if has_einops: + CASES = CASES_NO_ATTENTION + CASES_ATTENTION +else: + CASES = CASES_NO_ATTENTION + + +class TestAutoencoderKlMaisi(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape): + net = AutoencoderKlMaisi(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + self.assertEqual(result[2].shape, expected_latent_shape) + + @parameterized.expand(CASES) + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_with_convtranspose_and_checkpointing( + self, input_param, input_shape, expected_shape, expected_latent_shape + ): + input_param = input_param.copy() + input_param.update({"use_checkpointing": True, 
"use_convtranspose": True}) + net = AutoencoderKlMaisi(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + self.assertEqual(result[2].shape, expected_latent_shape) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + AutoencoderKlMaisi( + spatial_dims=3, + in_channels=1, + out_channels=1, + num_channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + num_splits=1, + debug=False, + ) + + def test_model_num_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + AutoencoderKlMaisi( + spatial_dims=3, + in_channels=1, + out_channels=1, + num_channels=(24, 24, 24), + attention_levels=(False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + num_splits=1, + debug=False, + ) + + def test_model_num_channels_not_same_size_of_num_res_blocks(self): + with self.assertRaises(ValueError): + AutoencoderKlMaisi( + spatial_dims=3, + in_channels=1, + out_channels=1, + num_channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=(8, 8), + norm_num_groups=16, + num_splits=1, + debug=False, + ) + + def test_shape_reconstruction(self): + input_param, input_shape, expected_shape, _ = CASES[0] + net = AutoencoderKlMaisi(**input_param).to(device) + with eval_mode(net): + result = net.reconstruct(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_reconstruction_with_convtranspose_and_checkpointing(self): + input_param, input_shape, expected_shape, _ = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKlMaisi(**input_param).to(device) + with eval_mode(net): + result = net.reconstruct(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_shape_encode(self): + input_param, input_shape, _, expected_latent_shape = CASES[0] + net = AutoencoderKlMaisi(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_encode_with_convtranspose_and_checkpointing(self): + input_param, input_shape, _, expected_latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKlMaisi(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_shape_sampling(self): + input_param, _, _, expected_latent_shape = CASES[0] + net = AutoencoderKlMaisi(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_sampling_convtranspose_and_checkpointing(self): + input_param, _, _, expected_latent_shape = CASES[0] + 
input_param = input_param.copy() + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKlMaisi(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + def test_shape_decode(self): + input_param, expected_input_shape, _, latent_shape = CASES[0] + net = AutoencoderKlMaisi(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_decode_convtranspose_and_checkpointing(self): + input_param, expected_input_shape, _, latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKlMaisi(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + @skipUnless(has_einops, "Requires einops") + def test_compatibility_with_monai_generative(self): + # test loading weights from a model saved in MONAI Generative, version 0.2.3 + with skip_if_downloading_fails(): + net = AutoencoderKlMaisi( + spatial_dims=3, + in_channels=1, + out_channels=1, + num_channels=(4, 4, 4), + latent_channels=4, + attention_levels=(False, False, True), + num_res_blocks=1, + norm_num_groups=4, + num_splits=1, + debug=False, + ).to(device) + + tmpdir = tempfile.mkdtemp() + key = "autoencoderkl_monai_generative_weights" + url = testing_data_config("models", key, "url") + hash_type = testing_data_config("models", key, "hash_type") + hash_val = testing_data_config("models", key, "hash_val") + filename = "autoencoderkl_monai_generative_weights.pt" + + weight_path = os.path.join(tmpdir, filename) + download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) + + net.load_old_state_dict(torch.load(weight_path), verbose=False) + + +if __name__ == "__main__": + unittest.main() From aab1ba3d14e5e0a1a073d637bdd3b9c9c928a06b Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Wed, 19 Jun 2024 17:00:04 +0000 Subject: [PATCH 07/37] update Signed-off-by: dongyang0122 --- .../maisi/networks/autoencoderkl_maisi.py | 72 +++++++++---------- tests/test_autoencoderkl_maisi.py | 6 +- 2 files changed, 36 insertions(+), 42 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 9f05a237e0..bf4dd831f9 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -9,15 +9,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
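The test module introduced above is built around one convention: each entry in `CASES` bundles constructor kwargs with the input, reconstruction, and latent shapes that one configuration should produce, and `parameterized.expand` fans the list out into one test per case. A stripped-down sketch of the same pattern, with a toy scale operation standing in for the network:

```python
# Sketch of the CASES convention; "scale" and TestScale are placeholders
# for illustration, not part of the MAISI code.
import unittest

import torch
from parameterized import parameterized

CASES = [
    [{"scale": 1.0}, (1, 1, 16, 16, 16), (1, 1, 16, 16, 16)],
    [{"scale": 2.0}, (1, 1, 32, 32, 32), (1, 1, 32, 32, 32)],
]


class TestScale(unittest.TestCase):
    @parameterized.expand(CASES)
    def test_shape(self, params, input_shape, expected_shape):
        result = params["scale"] * torch.randn(input_shape)
        self.assertEqual(tuple(result.shape), expected_shape)


if __name__ == "__main__":
    unittest.main()
```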
+from __future__ import annotations + import gc from typing import Sequence, Union import torch import torch.nn as nn import torch.nn.functional as F +from generative.networks.nets.autoencoderkl import AttentionBlock, AutoencoderKL, ResBlock import monai -from generative.networks.nets.autoencoderkl import AttentionBlock, ResBlock, AutoencoderKL class MaisiGroupNorm3D(nn.GroupNorm): @@ -43,12 +45,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if len(input.shape) != 5: raise ValueError("Expected a 5D tensor") - N, C, D, H, W = input.shape - input = input.view(N, self.num_groups, C // self.num_groups, D, H, W) + param_n, param_c, param_d, param_h, param_w = input.shape + input = input.view(param_n, self.num_groups, param_c // self.num_groups, param_d, param_h, param_w) inputs = [] for i in range(input.size(1)): - array = input[:, i:i + 1, ...].to(dtype=torch.float32) + array = input[:, i : i + 1, ...].to(dtype=torch.float32) mean = array.mean([2, 3, 4, 5], keepdim=True) std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_() inputs.append(array.sub_(mean).div_(std).to(dtype=torch.float16)) @@ -56,11 +58,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: del input torch.cuda.empty_cache() - input = torch.cat([inputs[k] for k in range(len(inputs))], dim=1) if max(inputs[0].size()) < 500 else self._cat_inputs(inputs) + input = ( + torch.cat([inputs[k] for k in range(len(inputs))], dim=1) + if max(inputs[0].size()) < 500 + else self._cat_inputs(inputs) + ) - input = input.view(N, C, D, H, W) + input = input.view(param_n, param_c, param_d, param_h, param_w) if self.affine: - input.mul_(self.weight.view(1, C, 1, 1, 1)).add_(self.bias.view(1, C, 1, 1, 1)) + input.mul_(self.weight.view(1, param_c, 1, 1, 1)).add_(self.bias.view(1, param_c, 1, 1, 1)) if self.debug: print("MaisiGroupNorm3D out", input.size()) @@ -242,7 +248,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: print(f"outputs after {i + 1}/{len(outputs)}:", outputs[i].size()) if max(outputs[0].size()) < 500: - x = torch.cat([out for out in outputs], dim=self.tp_dim + 2) + x = torch.cat(outputs, dim=self.tp_dim + 2) else: x = outputs[0].clone().to("cpu", non_blocking=True) outputs[0] = 0 @@ -363,11 +369,7 @@ def __init__( self.out_channels = in_channels if out_channels is None else out_channels self.norm1 = MaisiGroupNorm3D( - num_groups=norm_num_groups, - num_channels=in_channels, - eps=norm_eps, - affine=True, - debug=debug, + num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True, debug=debug ) self.conv1 = MaisiConvolution( spatial_dims=spatial_dims, @@ -381,11 +383,7 @@ def __init__( debug=debug, ) self.norm2 = MaisiGroupNorm3D( - num_groups=norm_num_groups, - num_channels=out_channels, - eps=norm_eps, - affine=True, - debug=debug, + num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True, debug=debug ) self.conv2 = MaisiConvolution( spatial_dims=spatial_dims, @@ -399,17 +397,21 @@ def __init__( debug=debug, ) - self.nin_shortcut = MaisiConvolution( - spatial_dims=spatial_dims, - in_channels=self.in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - num_splits=num_splits, - debug=debug, - ) if self.in_channels != self.out_channels else nn.Identity() + self.nin_shortcut = ( + MaisiConvolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + num_splits=num_splits, + 
debug=debug, + ) + if self.in_channels != self.out_channels + else nn.Identity() + ) def forward(self, x: torch.Tensor) -> torch.Tensor: h = self.norm1(x) @@ -565,11 +567,7 @@ def __init__( blocks.append( MaisiGroupNorm3D( - num_groups=norm_num_groups, - num_channels=num_channels[-1], - eps=norm_eps, - affine=True, - debug=debug, + num_groups=norm_num_groups, num_channels=num_channels[-1], eps=norm_eps, affine=True, debug=debug ) ) blocks.append( @@ -736,11 +734,7 @@ def __init__( blocks.append( MaisiGroupNorm3D( - num_groups=norm_num_groups, - num_channels=block_in_ch, - eps=norm_eps, - affine=True, - debug=debug, + num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True, debug=debug ) ) blocks.append( diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py index 8d7d77d3b4..63716b7c45 100644 --- a/tests/test_autoencoderkl_maisi.py +++ b/tests/test_autoencoderkl_maisi.py @@ -20,10 +20,10 @@ from parameterized import parameterized from monai.apps import download_url +from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi from monai.networks import eval_mode from monai.utils import optional_import from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, testing_data_config -from autoencoder_kl_maisi import AutoencoderKlMaisi # Assuming the class is in a file named autoencoder_kl_maisi.py tqdm, has_tqdm = optional_import("tqdm", name="tqdm") _, has_einops = optional_import("einops") @@ -50,7 +50,7 @@ (1, 1, 16, 16, 16), (1, 1, 16, 16, 16), (1, 4, 4, 4, 4), - ], + ] ] CASES_ATTENTION = [ @@ -70,7 +70,7 @@ (1, 1, 16, 16, 16), (1, 1, 16, 16, 16), (1, 4, 4, 4, 4), - ], + ] ] if has_einops: From 31420c532861976f4f02e4e377b9c05edef5a508 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Jun 2024 17:00:38 +0000 Subject: [PATCH 08/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../maisi/networks/autoencoderkl_maisi.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index bf4dd831f9..185d8ec802 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -12,7 +12,7 @@ from __future__ import annotations import gc -from typing import Sequence, Union +from typing import Sequence import torch import torch.nn as nn @@ -112,20 +112,20 @@ def __init__( out_channels: int, num_splits: int, debug: bool, - strides: Union[Sequence[int], int] = 1, - kernel_size: Union[Sequence[int], int] = 3, + strides: Sequence[int] | int = 1, + kernel_size: Sequence[int] | int = 3, adn_ordering: str = "NDA", - act: Union[tuple, str, None] = "PRELU", - norm: Union[tuple, str, None] = "INSTANCE", - dropout: Union[tuple, str, float, None] = None, + act: tuple | str | None = "PRELU", + norm: tuple | str | None = "INSTANCE", + dropout: tuple | str | float | None = None, dropout_dim: int = 1, - dilation: Union[Sequence[int], int] = 1, + dilation: Sequence[int] | int = 1, groups: int = 1, bias: bool = True, conv_only: bool = False, is_transposed: bool = False, - padding: Union[Sequence[int], int, None] = None, - output_padding: Union[Sequence[int], int, None] = None, + padding: Sequence[int] | int | None = None, + output_padding: Sequence[int] | int | None = None, 
) -> None: super().__init__() self.conv = monai.networks.blocks.Convolution( From 6f5a85d5ea873820a0505c1b42ff3a41bec95810 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Wed, 19 Jun 2024 17:50:12 +0000 Subject: [PATCH 09/37] update Signed-off-by: dongyang0122 --- .../maisi/networks/autoencoderkl_maisi.py | 57 +++++++++++++++++-- tests/test_autoencoderkl_maisi.py | 24 ++++---- 2 files changed, 63 insertions(+), 18 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index bf4dd831f9..3151ba93e6 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -31,11 +31,21 @@ class MaisiGroupNorm3D(nn.GroupNorm): num_channels: Number of channels for the group norm. eps: Epsilon value for numerical stability. affine: Whether to use learnable affine parameters. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format. debug: Whether to print debug information. """ - def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True, debug: bool = True): + def __init__( + self, + num_groups: int, + num_channels: int, + eps: float = 1e-5, + affine: bool = True, + norm_float16: bool = False, + debug: bool = True, + ): super().__init__(num_groups, num_channels, eps, affine) + self.norm_float16 = norm_float16 self.debug = debug def forward(self, input: torch.Tensor) -> torch.Tensor: @@ -53,7 +63,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: array = input[:, i : i + 1, ...].to(dtype=torch.float32) mean = array.mean([2, 3, 4, 5], keepdim=True) std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_() - inputs.append(array.sub_(mean).div_(std).to(dtype=torch.float16)) + if self.norm_float16: + inputs.append(array.sub_(mean).div_(std).to(dtype=torch.float16)) + else: + inputs.append(array.sub_(mean).div_(std)) del input torch.cuda.empty_cache() @@ -351,6 +364,7 @@ class MaisiResBlock(nn.Module): norm_eps: Epsilon for the normalization. out_channels: Number of output channels. num_splits: Number of splits for the input tensor. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format. debug: Whether to print debug information. """ @@ -362,6 +376,7 @@ def __init__( norm_eps: float, out_channels: int, num_splits: int, + norm_float16: bool, debug: bool, ) -> None: super().__init__() @@ -369,7 +384,12 @@ def __init__( self.out_channels = in_channels if out_channels is None else out_channels self.norm1 = MaisiGroupNorm3D( - num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True, debug=debug + num_groups=norm_num_groups, + num_channels=in_channels, + eps=norm_eps, + affine=True, + norm_float16=norm_float16, + debug=debug, ) self.conv1 = MaisiConvolution( spatial_dims=spatial_dims, @@ -383,7 +403,12 @@ def __init__( debug=debug, ) self.norm2 = MaisiGroupNorm3D( - num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True, debug=debug + num_groups=norm_num_groups, + num_channels=out_channels, + eps=norm_eps, + affine=True, + norm_float16=norm_float16, + debug=debug, ) self.conv2 = MaisiConvolution( spatial_dims=spatial_dims, @@ -452,6 +477,7 @@ class MaisiEncoder(nn.Module): attention_levels: Indicate which level from num_channels contain an attention block. with_nonlocal_attn: If True, use non-local attention block. 
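The change above makes the float16 cast opt-in: per-group mean and variance are still computed in float32 for numerical stability, and only the normalized activations are cast down when `norm_float16` is set, roughly halving activation memory. A minimal sketch of the idea on plain tensors (the real `MaisiGroupNorm3D` additionally applies the affine parameters and the chunked concatenation):

```python
# Sketch only: float32 statistics, optional float16 output.
import torch


def groupwise_norm(x: torch.Tensor, num_groups: int, eps: float = 1e-5, to_fp16: bool = False) -> torch.Tensor:
    n, c, d, h, w = x.shape
    g = x.view(n, num_groups, c // num_groups, d, h, w).to(torch.float32)
    mean = g.mean([2, 3, 4, 5], keepdim=True)
    std = g.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(eps).sqrt_()
    out = (g - mean) / std
    if to_fp16:
        out = out.to(torch.float16)  # cast only the result, never the stats
    return out.view(n, c, d, h, w)


x = torch.randn(1, 8, 4, 4, 4)
print(groupwise_norm(x, num_groups=4, to_fp16=True).dtype)  # torch.float16
```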
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format. num_splits: Number of splits for the input tensor. debug: Whether to print debug information. """ @@ -467,6 +493,7 @@ def __init__( norm_eps: float, attention_levels: Sequence[bool], num_splits: int, + norm_float16: bool, debug: bool, with_nonlocal_attn: bool = True, use_flash_attention: bool = False, @@ -513,6 +540,7 @@ def __init__( norm_eps=norm_eps, out_channels=output_channel, num_splits=num_splits, + norm_float16=norm_float16, debug=debug, ) ) @@ -567,7 +595,12 @@ def __init__( blocks.append( MaisiGroupNorm3D( - num_groups=norm_num_groups, num_channels=num_channels[-1], eps=norm_eps, affine=True, debug=debug + num_groups=norm_num_groups, + num_channels=num_channels[-1], + eps=norm_eps, + affine=True, + norm_float16=norm_float16, + debug=debug, ) ) blocks.append( @@ -610,6 +643,7 @@ class MaisiDecoder(nn.Module): use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. num_splits: Number of splits for the input tensor. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format. debug: Whether to print debug information. """ @@ -624,6 +658,7 @@ def __init__( norm_eps: float, attention_levels: Sequence[bool], num_splits: int, + norm_float16: bool, debug: bool, with_nonlocal_attn: bool = True, use_flash_attention: bool = False, @@ -705,6 +740,7 @@ def __init__( norm_eps=norm_eps, out_channels=block_out_ch, num_splits=num_splits, + norm_float16=norm_float16, debug=debug, ) ) @@ -734,7 +770,12 @@ def __init__( blocks.append( MaisiGroupNorm3D( - num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True, debug=debug + num_groups=norm_num_groups, + num_channels=block_in_ch, + eps=norm_eps, + affine=True, + norm_float16=norm_float16, + debug=debug, ) ) blocks.append( @@ -780,6 +821,7 @@ class AutoencoderKlMaisi(AutoencoderKL): use_checkpointing: If True, use activation checkpointing. use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. num_splits: Number of splits for the input tensor. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format. debug: Whether to print debug information. 
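One housekeeping change from the pre-commit patch just before this one is worth a note: `typing.Union` annotations were rewritten as PEP 604 `X | Y` unions. That syntax is only evaluated at runtime on Python 3.10+, but because the module now begins with `from __future__ import annotations`, annotations are stored as strings and never evaluated, so the file remains importable on the older interpreters MONAI supports. A self-contained sketch:

```python
# With the future import, this parses and runs on Python 3.8/3.9 even
# though `Sequence[int] | int` would otherwise need 3.10+.
from __future__ import annotations

from typing import Sequence


def describe(strides: Sequence[int] | int = 1, padding: Sequence[int] | int | None = None) -> str:
    return f"strides={strides}, padding={padding}"


print(describe([1, 2, 2]))
```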
""" @@ -800,6 +842,7 @@ def __init__( use_checkpointing: bool = False, use_convtranspose: bool = False, num_splits: int = 16, + norm_float16: bool = False, debug: bool = True, ) -> None: super().__init__( @@ -831,6 +874,7 @@ def __init__( with_nonlocal_attn=with_encoder_nonlocal_attn, use_flash_attention=use_flash_attention, num_splits=num_splits, + norm_float16=norm_float16, debug=debug, ) @@ -847,5 +891,6 @@ def __init__( use_flash_attention=use_flash_attention, use_convtranspose=use_convtranspose, num_splits=num_splits, + norm_float16=norm_float16, debug=debug, ) diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py index 63716b7c45..03579e64ab 100644 --- a/tests/test_autoencoderkl_maisi.py +++ b/tests/test_autoencoderkl_maisi.py @@ -40,11 +40,11 @@ "num_channels": (4, 4, 4), "latent_channels": 4, "attention_levels": (False, False, False), - "num_res_blocks": 1, + "num_res_blocks": (1, 1, 1), "norm_num_groups": 4, "with_encoder_nonlocal_attn": False, "with_decoder_nonlocal_attn": False, - "num_splits": 1, + "num_splits": 4, "debug": False, }, (1, 1, 16, 16, 16), @@ -62,9 +62,9 @@ "num_channels": (4, 4, 4), "latent_channels": 4, "attention_levels": (False, False, True), - "num_res_blocks": 1, + "num_res_blocks": (1, 1, 1), "norm_num_groups": 4, - "num_splits": 1, + "num_splits": 4, "debug": False, }, (1, 1, 16, 16, 16), @@ -112,9 +112,9 @@ def test_model_channels_not_multiple_of_norm_num_group(self): num_channels=(24, 24, 24), attention_levels=(False, False, False), latent_channels=8, - num_res_blocks=1, + num_res_blocks=(1, 1, 1), norm_num_groups=16, - num_splits=1, + num_splits=4, debug=False, ) @@ -127,9 +127,9 @@ def test_model_num_channels_not_same_size_of_attention_levels(self): num_channels=(24, 24, 24), attention_levels=(False, False), latent_channels=8, - num_res_blocks=1, + num_res_blocks=(1, 1, 1), norm_num_groups=16, - num_splits=1, + num_splits=4, debug=False, ) @@ -142,9 +142,9 @@ def test_model_num_channels_not_same_size_of_num_res_blocks(self): num_channels=(24, 24, 24), attention_levels=(False, False, False), latent_channels=8, - num_res_blocks=(8, 8), + num_res_blocks=(8, 8, 8), norm_num_groups=16, - num_splits=1, + num_splits=4, debug=False, ) @@ -233,9 +233,9 @@ def test_compatibility_with_monai_generative(self): num_channels=(4, 4, 4), latent_channels=4, attention_levels=(False, False, True), - num_res_blocks=1, + num_res_blocks=(1, 1, 1), norm_num_groups=4, - num_splits=1, + num_splits=4, debug=False, ).to(device) From 9da0710b8fddf1b93848f37546cfcf44994e3832 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Wed, 19 Jun 2024 18:34:05 +0000 Subject: [PATCH 10/37] update Signed-off-by: dongyang0122 --- .../maisi/networks/autoencoderkl_maisi.py | 88 +++++++++++-------- 1 file changed, 51 insertions(+), 37 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 3e1a0da16d..12d4fcc645 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -113,6 +113,7 @@ class MaisiConvolution(nn.Module): in_channels: Number of input channels. out_channels: Number of output channels. num_splits: Number of splits for the input tensor. + dim_split: Dimension of splitting for the input tensor. debug: Whether to print debug information. Additional arguments for the convolution operation. 
https://docs.monai.io/en/stable/networks.html#convolution @@ -124,6 +125,7 @@ def __init__( in_channels: int, out_channels: int, num_splits: int, + dim_split: int, debug: bool, strides: Sequence[int] | int = 1, kernel_size: Sequence[int] | int = 3, @@ -161,8 +163,8 @@ def __init__( output_padding=output_padding, ) - self.tp_dim = 1 - self.stride = strides[self.tp_dim] if isinstance(strides, list) else strides + self.dim_split = dim_split + self.stride = strides[self.dim_split] if isinstance(strides, list) else strides self.num_splits = num_splits self.debug = debug @@ -171,7 +173,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.debug: print("num_splits:", num_splits) - l = x.size(self.tp_dim + 2) + l = x.size(self.dim_split + 2) split_size = l // num_splits padding = 3 @@ -181,9 +183,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: print("padding:", padding) overlaps = [0] + [padding] * (num_splits - 1) - last_padding = x.size(self.tp_dim + 2) % split_size + last_padding = x.size(self.dim_split + 2) % split_size - if self.tp_dim == 0: + if self.dim_split == 0: splits = [ x[ :, @@ -196,7 +198,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ] for i in range(num_splits) ] - elif self.tp_dim == 1: + elif self.dim_split == 1: splits = [ x[ :, @@ -209,7 +211,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ] for i in range(num_splits) ] - elif self.tp_dim == 2: + elif self.dim_split == 2: splits = [ x[ :, @@ -238,20 +240,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: split_size_out = split_size padding_s = padding - non_tp_dim = self.tp_dim + 1 if self.tp_dim < 2 else 0 - if outputs[0].size(non_tp_dim + 2) // splits[0].size(non_tp_dim + 2) == 2: + non_dim_split = self.dim_split + 1 if self.dim_split < 2 else 0 + if outputs[0].size(non_dim_split + 2) // splits[0].size(non_dim_split + 2) == 2: split_size_out *= 2 padding_s *= 2 - if self.tp_dim == 0: + if self.dim_split == 0: outputs[0] = outputs[0][:, :, :split_size_out, :, :] for i in range(1, num_splits): outputs[i] = outputs[i][:, :, padding_s : padding_s + split_size_out, :, :] - elif self.tp_dim == 1: + elif self.dim_split == 1: outputs[0] = outputs[0][:, :, :, :split_size_out, :] for i in range(1, num_splits): outputs[i] = outputs[i][:, :, :, padding_s : padding_s + split_size_out, :] - elif self.tp_dim == 2: + elif self.dim_split == 2: outputs[0] = outputs[0][:, :, :, :, :split_size_out] for i in range(1, num_splits): outputs[i] = outputs[i][:, :, :, :, padding_s : padding_s + split_size_out] @@ -261,13 +263,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: print(f"outputs after {i + 1}/{len(outputs)}:", outputs[i].size()) if max(outputs[0].size()) < 500: - x = torch.cat(outputs, dim=self.tp_dim + 2) + x = torch.cat(outputs, dim=self.dim_split + 2) else: x = outputs[0].clone().to("cpu", non_blocking=True) outputs[0] = 0 torch.cuda.empty_cache() for k in range(len(outputs) - 1): - x = torch.cat((x, outputs[k + 1].cpu()), dim=self.tp_dim + 2) + x = torch.cat((x, outputs[k + 1].cpu()), dim=self.dim_split + 2) outputs[k + 1] = 0 torch.cuda.empty_cache() gc.collect() @@ -289,10 +291,13 @@ class MaisiUpsample(nn.Module): spatial_dims: Number of spatial dimensions (1D, 2D, 3D). in_channels: Number of input channels to the layer. use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. + num_splits: Number of splits for the input tensor. + dim_split: Dimension of splitting for the input tensor. + debug: Whether to print debug information. 
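`MaisiUpsample`, whose docstring gains the new arguments above, keeps the two standard upsampling routes of the base model: a transposed convolution, or trilinear interpolation followed by a resolution-preserving convolution. The same two paths with ordinary torch modules standing in for the split-aware `MaisiConvolution`:

```python
# Sketch of the two upsampling paths; the plain Conv3d/ConvTranspose3d
# modules here are stand-ins for MaisiConvolution, not the real layers.
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 8, 8)

# Path 1: transposed convolution doubles resolution directly.
up_transposed = nn.ConvTranspose3d(4, 4, kernel_size=2, stride=2)
print(up_transposed(x).shape)  # torch.Size([1, 4, 16, 16, 16])

# Path 2: trilinear interpolation, then a stride-1 convolution.
up_conv = nn.Conv3d(4, 4, kernel_size=3, padding=1)
y = F.interpolate(x, scale_factor=2.0, mode="trilinear")
print(up_conv(y).shape)  # torch.Size([1, 4, 16, 16, 16])
```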
""" def __init__( - self, spatial_dims: int, in_channels: int, use_convtranspose: bool, num_splits: int, debug: bool + self, spatial_dims: int, in_channels: int, use_convtranspose: bool, num_splits: int, dim_split: int, debug: bool ) -> None: super().__init__() self.conv = MaisiConvolution( @@ -305,6 +310,7 @@ def __init__( conv_only=True, is_transposed=use_convtranspose, num_splits=num_splits, + dim_split=dim_split, debug=debug, ) self.use_convtranspose = use_convtranspose @@ -328,10 +334,11 @@ class MaisiDownsample(nn.Module): spatial_dims: Number of spatial dimensions (1D, 2D, 3D). in_channels: Number of input channels. num_splits: Number of splits for the input tensor. + dim_split: Dimension of splitting for the input tensor. debug: Whether to print debug information. """ - def __init__(self, spatial_dims: int, in_channels: int, num_splits: int, debug: bool) -> None: + def __init__(self, spatial_dims: int, in_channels: int, num_splits: int, dim_split: int, debug: bool) -> None: super().__init__() self.pad = (0, 1) * spatial_dims self.conv = MaisiConvolution( @@ -343,6 +350,7 @@ def __init__(self, spatial_dims: int, in_channels: int, num_splits: int, debug: padding=0, conv_only=True, num_splits=num_splits, + dim_split=dim_split, debug=debug, ) @@ -364,6 +372,7 @@ class MaisiResBlock(nn.Module): norm_eps: Epsilon for the normalization. out_channels: Number of output channels. num_splits: Number of splits for the input tensor. + dim_split: Dimension of splitting for the input tensor. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format. debug: Whether to print debug information. """ @@ -376,6 +385,7 @@ def __init__( norm_eps: float, out_channels: int, num_splits: int, + dim_split: int, norm_float16: bool, debug: bool, ) -> None: @@ -400,6 +410,7 @@ def __init__( padding=1, conv_only=True, num_splits=num_splits, + dim_split=dim_split, debug=debug, ) self.norm2 = MaisiGroupNorm3D( @@ -419,6 +430,7 @@ def __init__( padding=1, conv_only=True, num_splits=num_splits, + dim_split=dim_split, debug=debug, ) @@ -432,6 +444,7 @@ def __init__( padding=0, conv_only=True, num_splits=num_splits, + dim_split=dim_split, debug=debug, ) if self.in_channels != self.out_channels @@ -479,6 +492,7 @@ class MaisiEncoder(nn.Module): use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format. num_splits: Number of splits for the input tensor. + dim_split: Dimension of splitting for the input tensor. debug: Whether to print debug information. 
""" @@ -493,21 +507,13 @@ def __init__( norm_eps: float, attention_levels: Sequence[bool], num_splits: int, + dim_split: int, norm_float16: bool, debug: bool, with_nonlocal_attn: bool = True, use_flash_attention: bool = False, ) -> None: super().__init__() - self.spatial_dims = spatial_dims - self.in_channels = in_channels - self.num_channels = num_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.norm_num_groups = norm_num_groups - self.norm_eps = norm_eps - self.attention_levels = attention_levels - self.num_splits = num_splits blocks = [] @@ -521,6 +527,7 @@ def __init__( padding=1, conv_only=True, num_splits=num_splits, + dim_split=dim_split, debug=debug, ) ) @@ -531,7 +538,7 @@ def __init__( output_channel = num_channels[i] is_final_block = i == len(num_channels) - 1 - for _ in range(self.num_res_blocks[i]): + for _ in range(num_res_blocks[i]): blocks.append( MaisiResBlock( spatial_dims=spatial_dims, @@ -540,6 +547,7 @@ def __init__( norm_eps=norm_eps, out_channels=output_channel, num_splits=num_splits, + dim_split=dim_split, norm_float16=norm_float16, debug=debug, ) @@ -559,7 +567,11 @@ def __init__( if not is_final_block: blocks.append( MaisiDownsample( - spatial_dims=spatial_dims, in_channels=input_channel, num_splits=num_splits, debug=debug + spatial_dims=spatial_dims, + in_channels=input_channel, + num_splits=num_splits, + dim_split=dim_split, + debug=debug, ) ) @@ -605,7 +617,7 @@ def __init__( ) blocks.append( MaisiConvolution( - spatial_dims=self.spatial_dims, + spatial_dims=spatial_dims, in_channels=num_channels[-1], out_channels=out_channels, strides=1, @@ -613,6 +625,7 @@ def __init__( padding=1, conv_only=True, num_splits=num_splits, + dim_split=dim_split, debug=debug, ) ) @@ -643,6 +656,7 @@ class MaisiDecoder(nn.Module): use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. num_splits: Number of splits for the input tensor. + dim_split: Dimension of splitting for the input tensor. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format. debug: Whether to print debug information. 
""" @@ -658,6 +672,7 @@ def __init__( norm_eps: float, attention_levels: Sequence[bool], num_splits: int, + dim_split: int, norm_float16: bool, debug: bool, with_nonlocal_attn: bool = True, @@ -665,15 +680,6 @@ def __init__( use_convtranspose: bool = False, ) -> None: super().__init__() - self.spatial_dims = spatial_dims - self.num_channels = num_channels - self.in_channels = in_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.norm_num_groups = norm_num_groups - self.norm_eps = norm_eps - self.attention_levels = attention_levels - self.num_splits = num_splits self.debug = debug reversed_block_out_channels = list(reversed(num_channels)) @@ -690,6 +696,7 @@ def __init__( padding=1, conv_only=True, num_splits=num_splits, + dim_split=dim_split, debug=debug, ) ) @@ -740,6 +747,7 @@ def __init__( norm_eps=norm_eps, out_channels=block_out_ch, num_splits=num_splits, + dim_split=dim_split, norm_float16=norm_float16, debug=debug, ) @@ -764,6 +772,7 @@ def __init__( in_channels=block_in_ch, use_convtranspose=use_convtranspose, num_splits=num_splits, + dim_split=dim_split, debug=debug, ) ) @@ -788,6 +797,7 @@ def __init__( padding=1, conv_only=True, num_splits=num_splits, + dim_split=dim_split, debug=debug, ) ) @@ -821,6 +831,7 @@ class AutoencoderKlMaisi(AutoencoderKL): use_checkpointing: If True, use activation checkpointing. use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. num_splits: Number of splits for the input tensor. + dim_split: Dimension of splitting for the input tensor. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format. debug: Whether to print debug information. """ @@ -842,6 +853,7 @@ def __init__( use_checkpointing: bool = False, use_convtranspose: bool = False, num_splits: int = 16, + dim_split: int = 0, norm_float16: bool = False, debug: bool = True, ) -> None: @@ -874,6 +886,7 @@ def __init__( with_nonlocal_attn=with_encoder_nonlocal_attn, use_flash_attention=use_flash_attention, num_splits=num_splits, + dim_split=dim_split, norm_float16=norm_float16, debug=debug, ) @@ -891,6 +904,7 @@ def __init__( use_flash_attention=use_flash_attention, use_convtranspose=use_convtranspose, num_splits=num_splits, + dim_split=dim_split, norm_float16=norm_float16, debug=debug, ) From 01d7cf99625faebfe80f37bb89e52a3cecfc62ae Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Fri, 21 Jun 2024 17:07:59 +0000 Subject: [PATCH 11/37] update Signed-off-by: dongyang0122 --- .../maisi/networks/autoencoderkl_maisi.py | 7 ++- requirements-dev.txt | 1 + tests/test_autoencoderkl_maisi.py | 57 ++++--------------- tests/testing_data/data_config.json | 5 ++ 4 files changed, 23 insertions(+), 47 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 12d4fcc645..d846eb82a9 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -244,6 +244,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if outputs[0].size(non_dim_split + 2) // splits[0].size(non_dim_split + 2) == 2: split_size_out *= 2 padding_s *= 2 + elif splits[0].size(non_dim_split + 2) // outputs[0].size(non_dim_split + 2) == 2: + split_size_out //= 2 + padding_s //= 2 if self.dim_split == 0: outputs[0] = outputs[0][:, :, :split_size_out, :, :] @@ -847,8 +850,8 @@ def __init__( latent_channels: int = 3, norm_num_groups: int = 32, norm_eps: float = 1e-6, - 
with_encoder_nonlocal_attn: bool = True, - with_decoder_nonlocal_attn: bool = True, + with_encoder_nonlocal_attn: bool = False, + with_decoder_nonlocal_attn: bool = False, use_flash_attention: bool = False, use_checkpointing: bool = False, use_convtranspose: bool = False, diff --git a/requirements-dev.txt b/requirements-dev.txt index a8ba25966b..37e5917c6a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -57,3 +57,4 @@ zarr lpips==0.1.4 nvidia-ml-py huggingface_hub +monai-generative diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py index 03579e64ab..c79cb9b003 100644 --- a/tests/test_autoencoderkl_maisi.py +++ b/tests/test_autoencoderkl_maisi.py @@ -11,19 +11,15 @@ from __future__ import annotations -import os -import tempfile import unittest -from unittest import skipUnless import torch from parameterized import parameterized -from monai.apps import download_url from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi from monai.networks import eval_mode from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, testing_data_config +from tests.utils import SkipIfBeforePyTorchVersion tqdm, has_tqdm = optional_import("tqdm", name="tqdm") _, has_einops = optional_import("einops") @@ -44,12 +40,12 @@ "norm_num_groups": 4, "with_encoder_nonlocal_attn": False, "with_decoder_nonlocal_attn": False, - "num_splits": 4, + "num_splits": 2, "debug": False, }, - (1, 1, 16, 16, 16), - (1, 1, 16, 16, 16), - (1, 4, 4, 4, 4), + (1, 1, 32, 32, 32), + (1, 1, 32, 32, 32), + (1, 4, 8, 8, 8), ] ] @@ -64,12 +60,12 @@ "attention_levels": (False, False, True), "num_res_blocks": (1, 1, 1), "norm_num_groups": 4, - "num_splits": 4, + "num_splits": 2, "debug": False, }, - (1, 1, 16, 16, 16), - (1, 1, 16, 16, 16), - (1, 4, 4, 4, 4), + (1, 1, 32, 32, 32), + (1, 1, 32, 32, 32), + (1, 4, 8, 8, 8), ] ] @@ -114,7 +110,7 @@ def test_model_channels_not_multiple_of_norm_num_group(self): latent_channels=8, num_res_blocks=(1, 1, 1), norm_num_groups=16, - num_splits=4, + num_splits=2, debug=False, ) @@ -129,7 +125,7 @@ def test_model_num_channels_not_same_size_of_attention_levels(self): latent_channels=8, num_res_blocks=(1, 1, 1), norm_num_groups=16, - num_splits=4, + num_splits=2, debug=False, ) @@ -144,7 +140,7 @@ def test_model_num_channels_not_same_size_of_num_res_blocks(self): latent_channels=8, num_res_blocks=(8, 8, 8), norm_num_groups=16, - num_splits=4, + num_splits=2, debug=False, ) @@ -222,35 +218,6 @@ def test_shape_decode_convtranspose_and_checkpointing(self): result = net.decode(torch.randn(latent_shape).to(device)) self.assertEqual(result.shape, expected_input_shape) - @skipUnless(has_einops, "Requires einops") - def test_compatibility_with_monai_generative(self): - # test loading weights from a model saved in MONAI Generative, version 0.2.3 - with skip_if_downloading_fails(): - net = AutoencoderKlMaisi( - spatial_dims=3, - in_channels=1, - out_channels=1, - num_channels=(4, 4, 4), - latent_channels=4, - attention_levels=(False, False, True), - num_res_blocks=(1, 1, 1), - norm_num_groups=4, - num_splits=4, - debug=False, - ).to(device) - - tmpdir = tempfile.mkdtemp() - key = "autoencoderkl_monai_generative_weights" - url = testing_data_config("models", key, "url") - hash_type = testing_data_config("models", key, "hash_type") - hash_val = testing_data_config("models", key, "hash_val") - filename = "autoencoderkl_monai_generative_weights.pt" - - weight_path = os.path.join(tmpdir, 
filename) - download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) - - net.load_old_state_dict(torch.load(weight_path), verbose=False) - if __name__ == "__main__": unittest.main() diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json index a570c787ba..4fda1d8c55 100644 --- a/tests/testing_data/data_config.json +++ b/tests/testing_data/data_config.json @@ -138,6 +138,11 @@ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth", "hash_type": "sha256", "hash_val": "c3564f40a6a051d3753a6d8fae5cc8eaf21ce8d82a9a3baf80748d15664055e8" + }, + "autoencoderkl_monai_generative_weights": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/autoencoderkl.pth", + "hash_type": "sha256", + "hash_val": "6e02c9540c51b16b9ba98b5c0c75d6b84b430afe9a3237df1d67a520f8d34184" } }, "configs": { From 3654336d223316c0454dbb9cd2392419d9ccd8f2 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Fri, 21 Jun 2024 17:37:42 +0000 Subject: [PATCH 12/37] update Signed-off-by: dongyang0122 --- monai/apps/generation/maisi/__init__.py | 21 +++++++++++++++++++++ requirements-dev.txt | 1 - 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/monai/apps/generation/maisi/__init__.py b/monai/apps/generation/maisi/__init__.py index 1e97f89407..ef42d42730 100644 --- a/monai/apps/generation/maisi/__init__.py +++ b/monai/apps/generation/maisi/__init__.py @@ -8,3 +8,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations + +import subprocess +import sys + + +def install_and_import(package, package_fullname=None): + if package_fullname is None: + package_fullname = package + + try: + __import__(package) + except ImportError: + print(f"'{package}' is not installed. 
Installing now...") + subprocess.check_call([sys.executable, "-m", "pip", "install", package_fullname]) + print(f"'{package}' installation completed.") + __import__(package) + + +install_and_import("generative", "monai-generative") diff --git a/requirements-dev.txt b/requirements-dev.txt index 37e5917c6a..a8ba25966b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -57,4 +57,3 @@ zarr lpips==0.1.4 nvidia-ml-py huggingface_hub -monai-generative From f7cbba1bd0cbe1a226475c0b351f317780c49f19 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Fri, 21 Jun 2024 18:07:41 +0000 Subject: [PATCH 13/37] update Signed-off-by: dongyang0122 --- .../apps/generation/maisi/networks/autoencoderkl_maisi.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index d846eb82a9..dbb45a0484 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -20,6 +20,7 @@ from generative.networks.nets.autoencoderkl import AttentionBlock, AutoencoderKL, ResBlock import monai +from monai.networks.blocks import Convolution class MaisiGroupNorm3D(nn.GroupNorm): @@ -143,7 +144,7 @@ def __init__( output_padding: Sequence[int] | int | None = None, ) -> None: super().__init__() - self.conv = monai.networks.blocks.Convolution( + self.conv = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, @@ -518,7 +519,7 @@ def __init__( ) -> None: super().__init__() - blocks = [] + blocks: list[nn.Module] = [] blocks.append( MaisiConvolution( @@ -687,7 +688,7 @@ def __init__( reversed_block_out_channels = list(reversed(num_channels)) - blocks = [] + blocks: list[nn.Module] = [] blocks.append( MaisiConvolution( From b820487c0f1209452377b60775c7a91e130362fb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Jun 2024 18:08:10 +0000 Subject: [PATCH 14/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index dbb45a0484..486f418799 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -19,7 +19,6 @@ import torch.nn.functional as F from generative.networks.nets.autoencoderkl import AttentionBlock, AutoencoderKL, ResBlock -import monai from monai.networks.blocks import Convolution From 69b14fd8feab64fde0ac3907a49243e8d977052b Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Fri, 21 Jun 2024 18:09:45 +0000 Subject: [PATCH 15/37] update Signed-off-by: dongyang0122 --- tests/testing_data/data_config.json | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json index 4fda1d8c55..a570c787ba 100644 --- a/tests/testing_data/data_config.json +++ b/tests/testing_data/data_config.json @@ -138,11 +138,6 @@ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth", "hash_type": "sha256", "hash_val": "c3564f40a6a051d3753a6d8fae5cc8eaf21ce8d82a9a3baf80748d15664055e8" - }, - "autoencoderkl_monai_generative_weights": { - "url": 
"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/autoencoderkl.pth", - "hash_type": "sha256", - "hash_val": "6e02c9540c51b16b9ba98b5c0c75d6b84b430afe9a3237df1d67a520f8d34184" } }, "configs": { From bee14ab77bf348174a83999b041fd6e876b4dd9c Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Fri, 21 Jun 2024 18:19:48 +0000 Subject: [PATCH 16/37] update Signed-off-by: dongyang0122 --- .../maisi/networks/autoencoderkl_maisi.py | 42 +++++++++++-------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 486f418799..00b2b6a148 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -22,6 +22,12 @@ from monai.networks.blocks import Convolution +def _empty_cuda_cache(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return + + class MaisiGroupNorm3D(nn.GroupNorm): """ Custom 3D Group Normalization with optional debug output. @@ -69,7 +75,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: inputs.append(array.sub_(mean).div_(std)) del input - torch.cuda.empty_cache() + _empty_cuda_cache() input = ( torch.cat([inputs[k] for k in range(len(inputs))], dim=1) @@ -90,12 +96,12 @@ def _cat_inputs(self, inputs): input_type = inputs[0].device.type input = inputs[0].clone().to("cpu", non_blocking=True) if input_type == "cuda" else inputs[0].clone() inputs[0] = 0 - torch.cuda.empty_cache() + _empty_cuda_cache() for k in range(len(inputs) - 1): input = torch.cat((input, inputs[k + 1].cpu()), dim=1) inputs[k + 1] = 0 - torch.cuda.empty_cache() + _empty_cuda_cache() gc.collect() if self.debug: @@ -230,7 +236,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: print(f"splits {j + 1}/{len(splits)}:", splits[j].size()) del x - torch.cuda.empty_cache() + _empty_cuda_cache() outputs = [self.conv(split) for split in splits] @@ -270,18 +276,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: else: x = outputs[0].clone().to("cpu", non_blocking=True) outputs[0] = 0 - torch.cuda.empty_cache() + _empty_cuda_cache() for k in range(len(outputs) - 1): x = torch.cat((x, outputs[k + 1].cpu()), dim=self.dim_split + 2) outputs[k + 1] = 0 - torch.cuda.empty_cache() + _empty_cuda_cache() gc.collect() if self.debug: print(f"MaisiConvolution cat: {k + 1}/{len(outputs) - 1}.") x = x.to("cuda", non_blocking=True) del outputs - torch.cuda.empty_cache() + _empty_cuda_cache() return x @@ -323,9 +329,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.conv(x) x = F.interpolate(x, scale_factor=2.0, mode="trilinear") - torch.cuda.empty_cache() + _empty_cuda_cache() x = self.conv(x) - torch.cuda.empty_cache() + _empty_cuda_cache() return x @@ -456,24 +462,24 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: h = self.norm1(x) - torch.cuda.empty_cache() + _empty_cuda_cache() h = F.silu(h) - torch.cuda.empty_cache() + _empty_cuda_cache() h = self.conv1(h) - torch.cuda.empty_cache() + _empty_cuda_cache() h = self.norm2(h) - torch.cuda.empty_cache() + _empty_cuda_cache() h = F.silu(h) - torch.cuda.empty_cache() + _empty_cuda_cache() h = self.conv2(h) - torch.cuda.empty_cache() + _empty_cuda_cache() if self.in_channels != self.out_channels: x = self.nin_shortcut(x) - torch.cuda.empty_cache() + _empty_cuda_cache() return x + h @@ -638,7 +644,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: for block in 
self.blocks: x = block(x) - torch.cuda.empty_cache() + _empty_cuda_cache() return x @@ -810,7 +816,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: for block in self.blocks: x = block(x) - torch.cuda.empty_cache() + _empty_cuda_cache() return x From e5217230f0166db8dcbc56b85bdae689602c8273 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Fri, 21 Jun 2024 20:55:32 +0000 Subject: [PATCH 17/37] fix output type Signed-off-by: dongyang0122 --- .../maisi/networks/autoencoderkl_maisi.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 00b2b6a148..c9ef3b75f8 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -326,13 +326,20 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: if self.use_convtranspose: - return self.conv(x) + x = self.conv(x) + + if isinstance(x, torch.Tensor): + return x + return torch.tensor(x) x = F.interpolate(x, scale_factor=2.0, mode="trilinear") _empty_cuda_cache() x = self.conv(x) _empty_cuda_cache() - return x + + if isinstance(x, torch.Tensor): + return x + return torch.tensor(x) class MaisiDownsample(nn.Module): @@ -481,7 +488,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.nin_shortcut(x) _empty_cuda_cache() - return x + h + out = x + h + if isinstance(out, torch.Tensor): + return out + return torch.tensor(out) class MaisiEncoder(nn.Module): From 5aaede3c9cce15dd612b52f20055d263a42c60d8 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Sat, 22 Jun 2024 22:12:13 +0000 Subject: [PATCH 18/37] update for loop indexing Signed-off-by: dongyang0122 --- .../maisi/networks/autoencoderkl_maisi.py | 64 ++++--------------- 1 file changed, 13 insertions(+), 51 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index c9ef3b75f8..62a60e882c 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -191,45 +191,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: overlaps = [0] + [padding] * (num_splits - 1) last_padding = x.size(self.dim_split + 2) % split_size - if self.dim_split == 0: - splits = [ - x[ - :, - :, - i * split_size - - overlaps[i] : (i + 1) * split_size - + (padding if i != num_splits - 1 else last_padding), - :, - :, - ] - for i in range(num_splits) - ] - elif self.dim_split == 1: - splits = [ - x[ - :, - :, - :, - i * split_size - - overlaps[i] : (i + 1) * split_size - + (padding if i != num_splits - 1 else last_padding), - :, - ] - for i in range(num_splits) - ] - elif self.dim_split == 2: - splits = [ - x[ - :, - :, - :, - :, - i * split_size - - overlaps[i] : (i + 1) * split_size - + (padding if i != num_splits - 1 else last_padding), - ] - for i in range(num_splits) - ] + slices = [slice(None)] * 5 + splits: list[torch.Tensor] = [] + for i in range(num_splits): + slices[self.dim_split + 2] = slice( + i * split_size - overlaps[i], (i + 1) * split_size + (padding if i != num_splits - 1 else last_padding) + ) + splits.append(x[tuple(slices)]) if self.debug: for j in range(len(splits)): @@ -254,18 +222,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: split_size_out //= 2 padding_s //= 2 - if self.dim_split == 0: - outputs[0] = outputs[0][:, :, :split_size_out, :, :] - 
for i in range(1, num_splits): - outputs[i] = outputs[i][:, :, padding_s : padding_s + split_size_out, :, :] - elif self.dim_split == 1: - outputs[0] = outputs[0][:, :, :, :split_size_out, :] - for i in range(1, num_splits): - outputs[i] = outputs[i][:, :, :, padding_s : padding_s + split_size_out, :] - elif self.dim_split == 2: - outputs[0] = outputs[0][:, :, :, :, :split_size_out] - for i in range(1, num_splits): - outputs[i] = outputs[i][:, :, :, :, padding_s : padding_s + split_size_out] + slices = [slice(None)] * 5 + for i in range(num_splits): + slices[self.dim_split + 2] = ( + slice(None, split_size_out) if i == 0 else slice(padding_s, padding_s + split_size_out) + ) + outputs[i] = outputs[i][tuple(slices)] if self.debug: for i in range(num_splits): From 285f19c0a6b13f6202d9e6ab6889b595341bd806 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Mon, 24 Jun 2024 15:54:07 +0000 Subject: [PATCH 19/37] update Signed-off-by: dongyang0122 --- monai/apps/generation/maisi/__init__.py | 21 ------------------- .../maisi/networks/autoencoderkl_maisi.py | 6 +++++- requirements-dev.txt | 1 + tests/test_autoencoderkl_maisi.py | 3 +++ 4 files changed, 9 insertions(+), 22 deletions(-) diff --git a/monai/apps/generation/maisi/__init__.py b/monai/apps/generation/maisi/__init__.py index ef42d42730..1e97f89407 100644 --- a/monai/apps/generation/maisi/__init__.py +++ b/monai/apps/generation/maisi/__init__.py @@ -8,24 +8,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from __future__ import annotations - -import subprocess -import sys - - -def install_and_import(package, package_fullname=None): - if package_fullname is None: - package_fullname = package - - try: - __import__(package) - except ImportError: - print(f"'{package}' is not installed. 
Installing now...") - subprocess.check_call([sys.executable, "-m", "pip", "install", package_fullname]) - print(f"'{package}' installation completed.") - __import__(package) - - -install_and_import("generative", "monai-generative") diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 62a60e882c..63850a0362 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -17,9 +17,13 @@ import torch import torch.nn as nn import torch.nn.functional as F -from generative.networks.nets.autoencoderkl import AttentionBlock, AutoencoderKL, ResBlock from monai.networks.blocks import Convolution +from monai.utils import optional_import + +AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock") +AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL") +ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock") def _empty_cuda_cache(): diff --git a/requirements-dev.txt b/requirements-dev.txt index a8ba25966b..37e5917c6a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -57,3 +57,4 @@ zarr lpips==0.1.4 nvidia-ml-py huggingface_hub +monai-generative diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py index c79cb9b003..6b9b61222e 100644 --- a/tests/test_autoencoderkl_maisi.py +++ b/tests/test_autoencoderkl_maisi.py @@ -12,6 +12,7 @@ from __future__ import annotations import unittest +from unittest import skipUnless import torch from parameterized import parameterized @@ -23,6 +24,7 @@ tqdm, has_tqdm = optional_import("tqdm", name="tqdm") _, has_einops = optional_import("einops") +_, has_generative = optional_import("generative") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -75,6 +77,7 @@ CASES = CASES_NO_ATTENTION +@skipUnless(has_generative, "monai-generative required") class TestAutoencoderKlMaisi(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape): From 0dc6196eb5ec786dcdd6d0d301b4591bc6f16053 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Mon, 24 Jun 2024 15:55:26 +0000 Subject: [PATCH 20/37] update Signed-off-by: dongyang0122 --- tests/test_autoencoderkl_maisi.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py index 6b9b61222e..50fdb48c7c 100644 --- a/tests/test_autoencoderkl_maisi.py +++ b/tests/test_autoencoderkl_maisi.py @@ -12,7 +12,6 @@ from __future__ import annotations import unittest -from unittest import skipUnless import torch from parameterized import parameterized @@ -77,7 +76,7 @@ CASES = CASES_NO_ATTENTION -@skipUnless(has_generative, "monai-generative required") +@unittest.skipUnless(has_generative, "monai-generative required") class TestAutoencoderKlMaisi(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape): From 6e4bf9a5e50feb82a4cd4107094405c9fc2f5238 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Mon, 24 Jun 2024 16:32:36 +0000 Subject: [PATCH 21/37] update Signed-off-by: dongyang0122 --- tests/test_autoencoderkl_maisi.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_autoencoderkl_maisi.py 
b/tests/test_autoencoderkl_maisi.py index 50fdb48c7c..81b309f4c1 100644 --- a/tests/test_autoencoderkl_maisi.py +++ b/tests/test_autoencoderkl_maisi.py @@ -16,7 +16,6 @@ import torch from parameterized import parameterized -from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi from monai.networks import eval_mode from monai.utils import optional_import from tests.utils import SkipIfBeforePyTorchVersion @@ -25,6 +24,9 @@ _, has_einops = optional_import("einops") _, has_generative = optional_import("generative") +if has_generative: + from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") From 7279629de5b784a6f9e9e9febb3a2fd951170b70 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Mon, 24 Jun 2024 18:14:33 +0000 Subject: [PATCH 22/37] update Signed-off-by: dongyang0122 --- .../generation/maisi/networks/autoencoderkl_maisi.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 63850a0362..4462e99ed1 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -12,7 +12,7 @@ from __future__ import annotations import gc -from typing import Sequence +from typing import TYPE_CHECKING, Sequence, cast import torch import torch.nn as nn @@ -26,6 +26,12 @@ ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock") +if TYPE_CHECKING: + from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType +else: + AutoencoderKLType = cast(type, AutoencoderKL) + + def _empty_cuda_cache(): if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -796,7 +802,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class AutoencoderKlMaisi(AutoencoderKL): +class AutoencoderKlMaisi(AutoencoderKLType): """ AutoencoderKL with custom MaisiEncoder and MaisiDecoder. From e881fce6afbb76a961527f906d80703d8778ebb0 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Wed, 26 Jun 2024 17:06:39 +0000 Subject: [PATCH 23/37] update Signed-off-by: dongyang0122 --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 4462e99ed1..0edba294e7 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -46,7 +46,7 @@ class MaisiGroupNorm3D(nn.GroupNorm): num_groups: Number of groups for the group norm. num_channels: Number of channels for the group norm. eps: Epsilon value for numerical stability. - affine: Whether to use learnable affine parameters. + affine: Whether to use learnable affine parameters, default to `True`. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format. debug: Whether to print debug information. 
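The last two patches above establish how the module can live in core MONAI while `monai-generative` stays optional: `optional_import` returns a lazy placeholder instead of raising at import time, and the `TYPE_CHECKING` branch hands static type checkers a real class to resolve while the runtime branch casts whatever the lazy import produced so it can serve as a base class. Condensed, assuming `monai` and `monai-generative` are both installed so the import resolves:

```python
# Sketch of the optional-dependency base-class pattern from the patches;
# TinyAutoencoder is a hypothetical subclass for illustration.
from __future__ import annotations

from typing import TYPE_CHECKING, cast

from monai.utils import optional_import

AutoencoderKL, has_autoencoderkl = optional_import(
    "generative.networks.nets.autoencoderkl", name="AutoencoderKL"
)

if TYPE_CHECKING:
    from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType
else:
    AutoencoderKLType = cast(type, AutoencoderKL)


class TinyAutoencoder(AutoencoderKLType):
    pass
```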
""" @@ -58,7 +58,7 @@ def __init__( eps: float = 1e-5, affine: bool = True, norm_float16: bool = False, - debug: bool = True, + debug: bool = False, ): super().__init__(num_groups, num_channels, eps, affine) self.norm_float16 = norm_float16 @@ -846,7 +846,7 @@ def __init__( num_splits: int = 16, dim_split: int = 0, norm_float16: bool = False, - debug: bool = True, + debug: bool = False, ) -> None: super().__init__( spatial_dims, From f7ebd966606c220f58b03d4e36fce1dc22994586 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:07:07 +0000 Subject: [PATCH 24/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 0edba294e7..53692339f2 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -46,7 +46,7 @@ class MaisiGroupNorm3D(nn.GroupNorm): num_groups: Number of groups for the group norm. num_channels: Number of channels for the group norm. eps: Epsilon value for numerical stability. - affine: Whether to use learnable affine parameters, default to `True`. + affine: Whether to use learnable affine parameters, default to `True`. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format. debug: Whether to print debug information. """ From 9bd8d91f1cc52460e876c354c4877756824ed371 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Wed, 26 Jun 2024 17:07:46 +0000 Subject: [PATCH 25/37] update Signed-off-by: dongyang0122 --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 0edba294e7..187dc0e972 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -47,8 +47,8 @@ class MaisiGroupNorm3D(nn.GroupNorm): num_channels: Number of channels for the group norm. eps: Epsilon value for numerical stability. affine: Whether to use learnable affine parameters, default to `True`. - norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format. - debug: Whether to print debug information. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. + debug: Whether to print debug information, default to `False`. 
""" def __init__( From 43ed14493aa6654a6821dd476075946eb1aee247 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:08:46 +0000 Subject: [PATCH 26/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 187dc0e972..327cce949f 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -46,7 +46,7 @@ class MaisiGroupNorm3D(nn.GroupNorm): num_groups: Number of groups for the group norm. num_channels: Number of channels for the group norm. eps: Epsilon value for numerical stability. - affine: Whether to use learnable affine parameters, default to `True`. + affine: Whether to use learnable affine parameters, default to `True`. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. debug: Whether to print debug information, default to `False`. """ From 5353a7649f0a003107fbcbcdd845d7ea8ceb9229 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Wed, 26 Jun 2024 17:30:10 +0000 Subject: [PATCH 27/37] update Signed-off-by: dongyang0122 --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 187dc0e972..c05cb90e19 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -46,7 +46,7 @@ class MaisiGroupNorm3D(nn.GroupNorm): num_groups: Number of groups for the group norm. num_channels: Number of channels for the group norm. eps: Epsilon value for numerical stability. - affine: Whether to use learnable affine parameters, default to `True`. + affine: Whether to use learnable affine parameters, default to `True`. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. debug: Whether to print debug information, default to `False`. 
""" @@ -80,9 +80,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: mean = array.mean([2, 3, 4, 5], keepdim=True) std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_() if self.norm_float16: - inputs.append(array.sub_(mean).div_(std).to(dtype=torch.float16)) + inputs.append(((array - mean) / std).to(dtype=torch.float16)) else: - inputs.append(array.sub_(mean).div_(std)) + inputs.append((array - mean) / std) del input _empty_cuda_cache() From 08ab45939053b8d3899f9b861e1fd124fc3d3df0 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Wed, 26 Jun 2024 17:53:59 +0000 Subject: [PATCH 28/37] update Signed-off-by: dongyang0122 --- .../maisi/networks/autoencoderkl_maisi.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index c05cb90e19..2696732630 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -361,8 +361,8 @@ class MaisiResBlock(nn.Module): out_channels: Number of output channels. num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. - norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format. - debug: Whether to print debug information. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. + debug: Whether to print debug information, default to `False`. """ def __init__( @@ -374,8 +374,8 @@ def __init__( out_channels: int, num_splits: int, dim_split: int, - norm_float16: bool, - debug: bool, + norm_float16: bool = False, + debug: bool = False, ) -> None: super().__init__() self.in_channels = in_channels @@ -481,10 +481,10 @@ class MaisiEncoder(nn.Module): attention_levels: Indicate which level from num_channels contain an attention block. with_nonlocal_attn: If True, use non-local attention block. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. - norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format. num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. - debug: Whether to print debug information. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. + debug: Whether to print debug information, default to `False`. """ def __init__( @@ -499,8 +499,8 @@ def __init__( attention_levels: Sequence[bool], num_splits: int, dim_split: int, - norm_float16: bool, - debug: bool, + norm_float16: bool = False, + debug: bool = False, with_nonlocal_attn: bool = True, use_flash_attention: bool = False, ) -> None: @@ -648,8 +648,8 @@ class MaisiDecoder(nn.Module): use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. - norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format. - debug: Whether to print debug information. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. + debug: Whether to print debug information, default to `False`. 
""" def __init__( @@ -664,8 +664,8 @@ def __init__( attention_levels: Sequence[bool], num_splits: int, dim_split: int, - norm_float16: bool, - debug: bool, + norm_float16: bool = False, + debug: bool = False, with_nonlocal_attn: bool = True, use_flash_attention: bool = False, use_convtranspose: bool = False, @@ -823,8 +823,8 @@ class AutoencoderKlMaisi(AutoencoderKLType): use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. - norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format. - debug: Whether to print debug information. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. + debug: Whether to print debug information, default to `False`. """ def __init__( From 34dab58746196bf1a0d99f0f15d4d2f8ce4847be Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Thu, 27 Jun 2024 00:03:28 +0000 Subject: [PATCH 29/37] update Signed-off-by: dongyang0122 --- .../maisi/networks/autoencoderkl_maisi.py | 129 ++++++++++-------- tests/test_autoencoderkl_maisi.py | 10 +- 2 files changed, 76 insertions(+), 63 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 2696732630..b38eb282dc 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -12,6 +12,7 @@ from __future__ import annotations import gc +import logging from typing import TYPE_CHECKING, Sequence, cast import torch @@ -32,6 +33,11 @@ AutoencoderKLType = cast(type, AutoencoderKL) +# Set up logging configuration +logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + def _empty_cuda_cache(): if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -40,7 +46,7 @@ def _empty_cuda_cache(): class MaisiGroupNorm3D(nn.GroupNorm): """ - Custom 3D Group Normalization with optional debug output. + Custom 3D Group Normalization with optional print_info output. Args: num_groups: Number of groups for the group norm. @@ -48,7 +54,7 @@ class MaisiGroupNorm3D(nn.GroupNorm): eps: Epsilon value for numerical stability. affine: Whether to use learnable affine parameters, default to `True`. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. - debug: Whether to print debug information, default to `False`. + print_info: Whether to print information, default to `False`. 
""" def __init__( @@ -58,15 +64,15 @@ def __init__( eps: float = 1e-5, affine: bool = True, norm_float16: bool = False, - debug: bool = False, + print_info: bool = False, ): super().__init__(num_groups, num_channels, eps, affine) self.norm_float16 = norm_float16 - self.debug = debug + self.print_info = print_info def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.debug: - print("MaisiGroupNorm3D in", input.size()) + if self.print_info: + logger.info(f"MaisiGroupNorm3D with input size: {input.size()}") if len(input.shape) != 5: raise ValueError("Expected a 5D tensor") @@ -97,8 +103,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.affine: input.mul_(self.weight.view(1, param_c, 1, 1, 1)).add_(self.bias.view(1, param_c, 1, 1, 1)) - if self.debug: - print("MaisiGroupNorm3D out", input.size()) + if self.print_info: + logger.info(f"MaisiGroupNorm3D with output size: {input.size()}") return input @@ -114,15 +120,15 @@ def _cat_inputs(self, inputs): _empty_cuda_cache() gc.collect() - if self.debug: - print(f"MaisiGroupNorm3D cat: {k + 1}/{len(inputs) - 1}.") + if self.print_info: + logger.info(f"MaisiGroupNorm3D concat progress: {k + 1}/{len(inputs) - 1}.") return input.to("cuda", non_blocking=True) if input_type == "cuda" else input class MaisiConvolution(nn.Module): """ - Convolutional layer with optional debug output and custom splitting mechanism. + Convolutional layer with optional print_info output and custom splitting mechanism. Args: spatial_dims: Number of spatial dimensions (1D, 2D, 3D). @@ -130,7 +136,7 @@ class MaisiConvolution(nn.Module): out_channels: Number of output channels. num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. - debug: Whether to print debug information. + print_info: Whether to print information. Additional arguments for the convolution operation. 
https://docs.monai.io/en/stable/networks.html#convolution """ @@ -142,7 +148,7 @@ def __init__( out_channels: int, num_splits: int, dim_split: int, - debug: bool, + print_info: bool, strides: Sequence[int] | int = 1, kernel_size: Sequence[int] | int = 3, adn_ordering: str = "NDA", @@ -182,12 +188,12 @@ def __init__( self.dim_split = dim_split self.stride = strides[self.dim_split] if isinstance(strides, list) else strides self.num_splits = num_splits - self.debug = debug + self.print_info = print_info def forward(self, x: torch.Tensor) -> torch.Tensor: num_splits = self.num_splits - if self.debug: - print("num_splits:", num_splits) + if self.print_info: + logger.info(f"Number of splits: {num_splits}") l = x.size(self.dim_split + 2) split_size = l // num_splits @@ -195,8 +201,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: padding = 3 if padding % self.stride > 0: padding = (padding // self.stride + 1) * self.stride - if self.debug: - print("padding:", padding) + if self.print_info: + logger.info(f"Padding size: {padding}") overlaps = [0] + [padding] * (num_splits - 1) last_padding = x.size(self.dim_split + 2) % split_size @@ -209,18 +215,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) splits.append(x[tuple(slices)]) - if self.debug: + if self.print_info: for j in range(len(splits)): - print(f"splits {j + 1}/{len(splits)}:", splits[j].size()) + logger.info(f"split {j + 1}/{len(splits)} size: {splits[j].size()}") del x _empty_cuda_cache() outputs = [self.conv(split) for split in splits] - if self.debug: + if self.print_info: for j in range(len(outputs)): - print(f"outputs before {j + 1}/{len(outputs)}:", outputs[j].size()) + logger.info(f"output {j + 1}/{len(outputs)} size before: {outputs[j].size()}") split_size_out = split_size padding_s = padding @@ -239,9 +245,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) outputs[i] = outputs[i][tuple(slices)] - if self.debug: + if self.print_info: for i in range(num_splits): - print(f"outputs after {i + 1}/{len(outputs)}:", outputs[i].size()) + logger.info(f"output {i + 1}/{len(outputs)} size after: {outputs[i].size()}") if max(outputs[0].size()) < 500: x = torch.cat(outputs, dim=self.dim_split + 2) @@ -254,8 +260,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: outputs[k + 1] = 0 _empty_cuda_cache() gc.collect() - if self.debug: - print(f"MaisiConvolution cat: {k + 1}/{len(outputs) - 1}.") + if self.print_info: + logger.info(f"MaisiConvolution concat progress: {k + 1}/{len(outputs) - 1}.") + x = x.to("cuda", non_blocking=True) del outputs @@ -274,11 +281,17 @@ class MaisiUpsample(nn.Module): use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. - debug: Whether to print debug information. + print_info: Whether to print information. """ def __init__( - self, spatial_dims: int, in_channels: int, use_convtranspose: bool, num_splits: int, dim_split: int, debug: bool + self, + spatial_dims: int, + in_channels: int, + use_convtranspose: bool, + num_splits: int, + dim_split: int, + print_info: bool, ) -> None: super().__init__() self.conv = MaisiConvolution( @@ -292,7 +305,7 @@ def __init__( is_transposed=use_convtranspose, num_splits=num_splits, dim_split=dim_split, - debug=debug, + print_info=print_info, ) self.use_convtranspose = use_convtranspose @@ -323,10 +336,10 @@ class MaisiDownsample(nn.Module): in_channels: Number of input channels. 
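The split-convolve-trim loop in MaisiConvolution above relies on a standard identity: for a kernel-3, padding-1 convolution, convolving overlapping pieces and discarding the halo regions reproduces convolving the whole tensor, provided the overlap covers the kernel's receptive field. A small standalone check along the depth axis (two splits, unit stride), independent of the class machinery:

import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv3d(1, 1, kernel_size=3, padding=1)
x = torch.randn(1, 1, 16, 24, 24)

full = conv(x)

halo = 2  # must be at least the kernel's one-sided reach (1 here); 2 is safely wide
top = conv(x[:, :, : 8 + halo])[:, :, :8]        # keep the first half, drop the halo
bottom = conv(x[:, :, 8 - halo :])[:, :, halo:]  # drop the halo, keep the second half
stitched = torch.cat((top, bottom), dim=2)

print(torch.allclose(full, stitched, atol=1e-5))  # expected: True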
num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. - debug: Whether to print debug information. + print_info: Whether to print information. """ - def __init__(self, spatial_dims: int, in_channels: int, num_splits: int, dim_split: int, debug: bool) -> None: + def __init__(self, spatial_dims: int, in_channels: int, num_splits: int, dim_split: int, print_info: bool) -> None: super().__init__() self.pad = (0, 1) * spatial_dims self.conv = MaisiConvolution( @@ -339,7 +352,7 @@ def __init__(self, spatial_dims: int, in_channels: int, num_splits: int, dim_spl conv_only=True, num_splits=num_splits, dim_split=dim_split, - debug=debug, + print_info=print_info, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -362,7 +375,7 @@ class MaisiResBlock(nn.Module): num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. - debug: Whether to print debug information, default to `False`. + print_info: Whether to print information, default to `False`. """ def __init__( @@ -375,7 +388,7 @@ def __init__( num_splits: int, dim_split: int, norm_float16: bool = False, - debug: bool = False, + print_info: bool = False, ) -> None: super().__init__() self.in_channels = in_channels @@ -387,7 +400,7 @@ def __init__( eps=norm_eps, affine=True, norm_float16=norm_float16, - debug=debug, + print_info=print_info, ) self.conv1 = MaisiConvolution( spatial_dims=spatial_dims, @@ -399,7 +412,7 @@ def __init__( conv_only=True, num_splits=num_splits, dim_split=dim_split, - debug=debug, + print_info=print_info, ) self.norm2 = MaisiGroupNorm3D( num_groups=norm_num_groups, @@ -407,7 +420,7 @@ def __init__( eps=norm_eps, affine=True, norm_float16=norm_float16, - debug=debug, + print_info=print_info, ) self.conv2 = MaisiConvolution( spatial_dims=spatial_dims, @@ -419,7 +432,7 @@ def __init__( conv_only=True, num_splits=num_splits, dim_split=dim_split, - debug=debug, + print_info=print_info, ) self.nin_shortcut = ( @@ -433,7 +446,7 @@ def __init__( conv_only=True, num_splits=num_splits, dim_split=dim_split, - debug=debug, + print_info=print_info, ) if self.in_channels != self.out_channels else nn.Identity() @@ -484,7 +497,7 @@ class MaisiEncoder(nn.Module): num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. - debug: Whether to print debug information, default to `False`. + print_info: Whether to print information, default to `False`. 
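MaisiDownsample above keeps the usual VAE downsampling recipe: pad one voxel on the trailing side of each spatial dimension (the (0, 1) * spatial_dims tuple), then apply a stride-2, kernel-3, padding-0 convolution so every spatial extent is exactly halved. The forward body is not shown in this hunk, so the following shape check is a sketch under that assumption:

import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv3d(4, 4, kernel_size=3, stride=2, padding=0)
x = torch.randn(1, 4, 32, 32, 32)
x = F.pad(x, (0, 1) * 3, mode="constant", value=0.0)  # pads to (1, 4, 33, 33, 33)
print(conv(x).shape)  # expected: torch.Size([1, 4, 16, 16, 16])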
""" def __init__( @@ -500,7 +513,7 @@ def __init__( num_splits: int, dim_split: int, norm_float16: bool = False, - debug: bool = False, + print_info: bool = False, with_nonlocal_attn: bool = True, use_flash_attention: bool = False, ) -> None: @@ -519,7 +532,7 @@ def __init__( conv_only=True, num_splits=num_splits, dim_split=dim_split, - debug=debug, + print_info=print_info, ) ) @@ -540,7 +553,7 @@ def __init__( num_splits=num_splits, dim_split=dim_split, norm_float16=norm_float16, - debug=debug, + print_info=print_info, ) ) input_channel = output_channel @@ -562,7 +575,7 @@ def __init__( in_channels=input_channel, num_splits=num_splits, dim_split=dim_split, - debug=debug, + print_info=print_info, ) ) @@ -603,7 +616,7 @@ def __init__( eps=norm_eps, affine=True, norm_float16=norm_float16, - debug=debug, + print_info=print_info, ) ) blocks.append( @@ -617,7 +630,7 @@ def __init__( conv_only=True, num_splits=num_splits, dim_split=dim_split, - debug=debug, + print_info=print_info, ) ) @@ -649,7 +662,7 @@ class MaisiDecoder(nn.Module): num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. - debug: Whether to print debug information, default to `False`. + print_info: Whether to print information, default to `False`. """ def __init__( @@ -665,13 +678,13 @@ def __init__( num_splits: int, dim_split: int, norm_float16: bool = False, - debug: bool = False, + print_info: bool = False, with_nonlocal_attn: bool = True, use_flash_attention: bool = False, use_convtranspose: bool = False, ) -> None: super().__init__() - self.debug = debug + self.print_info = print_info reversed_block_out_channels = list(reversed(num_channels)) @@ -688,7 +701,7 @@ def __init__( conv_only=True, num_splits=num_splits, dim_split=dim_split, - debug=debug, + print_info=print_info, ) ) @@ -740,7 +753,7 @@ def __init__( num_splits=num_splits, dim_split=dim_split, norm_float16=norm_float16, - debug=debug, + print_info=print_info, ) ) block_in_ch = block_out_ch @@ -764,7 +777,7 @@ def __init__( use_convtranspose=use_convtranspose, num_splits=num_splits, dim_split=dim_split, - debug=debug, + print_info=print_info, ) ) @@ -775,7 +788,7 @@ def __init__( eps=norm_eps, affine=True, norm_float16=norm_float16, - debug=debug, + print_info=print_info, ) ) blocks.append( @@ -789,7 +802,7 @@ def __init__( conv_only=True, num_splits=num_splits, dim_split=dim_split, - debug=debug, + print_info=print_info, ) ) @@ -824,7 +837,7 @@ class AutoencoderKlMaisi(AutoencoderKLType): num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. - debug: Whether to print debug information, default to `False`. + print_info: Whether to print information, default to `False`. 
""" def __init__( @@ -846,7 +859,7 @@ def __init__( num_splits: int = 16, dim_split: int = 0, norm_float16: bool = False, - debug: bool = False, + print_info: bool = False, ) -> None: super().__init__( spatial_dims, @@ -879,7 +892,7 @@ def __init__( num_splits=num_splits, dim_split=dim_split, norm_float16=norm_float16, - debug=debug, + print_info=print_info, ) self.decoder = MaisiDecoder( @@ -897,5 +910,5 @@ def __init__( num_splits=num_splits, dim_split=dim_split, norm_float16=norm_float16, - debug=debug, + print_info=print_info, ) diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py index 81b309f4c1..9b261e0b75 100644 --- a/tests/test_autoencoderkl_maisi.py +++ b/tests/test_autoencoderkl_maisi.py @@ -44,7 +44,7 @@ "with_encoder_nonlocal_attn": False, "with_decoder_nonlocal_attn": False, "num_splits": 2, - "debug": False, + "print_info": False, }, (1, 1, 32, 32, 32), (1, 1, 32, 32, 32), @@ -64,7 +64,7 @@ "num_res_blocks": (1, 1, 1), "norm_num_groups": 4, "num_splits": 2, - "debug": False, + "print_info": False, }, (1, 1, 32, 32, 32), (1, 1, 32, 32, 32), @@ -115,7 +115,7 @@ def test_model_channels_not_multiple_of_norm_num_group(self): num_res_blocks=(1, 1, 1), norm_num_groups=16, num_splits=2, - debug=False, + print_info=False, ) def test_model_num_channels_not_same_size_of_attention_levels(self): @@ -130,7 +130,7 @@ def test_model_num_channels_not_same_size_of_attention_levels(self): num_res_blocks=(1, 1, 1), norm_num_groups=16, num_splits=2, - debug=False, + print_info=False, ) def test_model_num_channels_not_same_size_of_num_res_blocks(self): @@ -145,7 +145,7 @@ def test_model_num_channels_not_same_size_of_num_res_blocks(self): num_res_blocks=(8, 8, 8), norm_num_groups=16, num_splits=2, - debug=False, + print_info=False, ) def test_shape_reconstruction(self): From b7d0f1b2c809c4f32b26151304ae7cc5c71913b5 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Thu, 27 Jun 2024 00:09:07 +0000 Subject: [PATCH 30/37] update Signed-off-by: dongyang0122 --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index b38eb282dc..f93b7150bb 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -93,11 +93,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: del input _empty_cuda_cache() - input = ( - torch.cat([inputs[k] for k in range(len(inputs))], dim=1) - if max(inputs[0].size()) < 500 - else self._cat_inputs(inputs) - ) + input = torch.cat(inputs, dim=1) if max(inputs[0].size()) < 500 else self._cat_inputs(inputs) input = input.view(param_n, param_c, param_d, param_h, param_w) if self.affine: From 86a983e0afa39fbb41c7fd10376686ac805f63e7 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Thu, 27 Jun 2024 13:43:16 +0000 Subject: [PATCH 31/37] update Signed-off-by: dongyang0122 --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index f93b7150bb..5792a48781 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -213,7 +213,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if 
self.print_info: for j in range(len(splits)): - logger.info(f"split {j + 1}/{len(splits)} size: {splits[j].size()}") + logger.info(f"Split {j + 1}/{len(splits)} size: {splits[j].size()}") del x _empty_cuda_cache() @@ -222,7 +222,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.print_info: for j in range(len(outputs)): - logger.info(f"output {j + 1}/{len(outputs)} size before: {outputs[j].size()}") + logger.info(f"Output {j + 1}/{len(outputs)} size before: {outputs[j].size()}") split_size_out = split_size padding_s = padding @@ -243,7 +243,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.print_info: for i in range(num_splits): - logger.info(f"output {i + 1}/{len(outputs)} size after: {outputs[i].size()}") + logger.info(f"Output {i + 1}/{len(outputs)} size after: {outputs[i].size()}") if max(outputs[0].size()) < 500: x = torch.cat(outputs, dim=self.dim_split + 2) From a712a9b5b6ce9789db449849db844273b17c9036 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Thu, 27 Jun 2024 16:01:45 +0000 Subject: [PATCH 32/37] update Signed-off-by: dongyang0122 --- .../maisi/networks/autoencoderkl_maisi.py | 96 +++++++++++-------- 1 file changed, 54 insertions(+), 42 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 5792a48781..c9adad0b9b 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -186,28 +186,16 @@ def __init__( self.num_splits = num_splits self.print_info = print_info - def forward(self, x: torch.Tensor) -> torch.Tensor: - num_splits = self.num_splits - if self.print_info: - logger.info(f"Number of splits: {num_splits}") - - l = x.size(self.dim_split + 2) - split_size = l // num_splits - - padding = 3 - if padding % self.stride > 0: - padding = (padding // self.stride + 1) * self.stride - if self.print_info: - logger.info(f"Padding size: {padding}") - - overlaps = [0] + [padding] * (num_splits - 1) + def _split_tensor(self, x: torch.Tensor, split_size: int, padding: int) -> list[torch.Tensor]: + overlaps = [0] + [padding] * (self.num_splits - 1) last_padding = x.size(self.dim_split + 2) % split_size slices = [slice(None)] * 5 splits: list[torch.Tensor] = [] - for i in range(num_splits): + for i in range(self.num_splits): slices[self.dim_split + 2] = slice( - i * split_size - overlaps[i], (i + 1) * split_size + (padding if i != num_splits - 1 else last_padding) + i * split_size - overlaps[i], + (i + 1) * split_size + (padding if i != self.num_splits - 1 else last_padding), ) splits.append(x[tuple(slices)]) @@ -215,51 +203,75 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: for j in range(len(splits)): logger.info(f"Split {j + 1}/{len(splits)} size: {splits[j].size()}") - del x - _empty_cuda_cache() - - outputs = [self.conv(split) for split in splits] - - if self.print_info: - for j in range(len(outputs)): - logger.info(f"Output {j + 1}/{len(outputs)} size before: {outputs[j].size()}") - - split_size_out = split_size - padding_s = padding - non_dim_split = self.dim_split + 1 if self.dim_split < 2 else 0 - if outputs[0].size(non_dim_split + 2) // splits[0].size(non_dim_split + 2) == 2: - split_size_out *= 2 - padding_s *= 2 - elif splits[0].size(non_dim_split + 2) // outputs[0].size(non_dim_split + 2) == 2: - split_size_out //= 2 - padding_s //= 2 + return splits + def _concatenate_tensors(self, outputs: list[torch.Tensor], split_size: int, padding: int) -> torch.Tensor: 
slices = [slice(None)] * 5 - for i in range(num_splits): - slices[self.dim_split + 2] = ( - slice(None, split_size_out) if i == 0 else slice(padding_s, padding_s + split_size_out) - ) + for i in range(self.num_splits): + slices[self.dim_split + 2] = slice(None, split_size) if i == 0 else slice(padding, padding + split_size) outputs[i] = outputs[i][tuple(slices)] if self.print_info: - for i in range(num_splits): + for i in range(self.num_splits): logger.info(f"Output {i + 1}/{len(outputs)} size after: {outputs[i].size()}") if max(outputs[0].size()) < 500: x = torch.cat(outputs, dim=self.dim_split + 2) else: x = outputs[0].clone().to("cpu", non_blocking=True) - outputs[0] = 0 + outputs[0] = torch.Tensor(0) _empty_cuda_cache() for k in range(len(outputs) - 1): x = torch.cat((x, outputs[k + 1].cpu()), dim=self.dim_split + 2) - outputs[k + 1] = 0 + outputs[k + 1] = torch.Tensor(0) _empty_cuda_cache() gc.collect() if self.print_info: logger.info(f"MaisiConvolution concat progress: {k + 1}/{len(outputs) - 1}.") x = x.to("cuda", non_blocking=True) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.print_info: + logger.info(f"Number of splits: {self.num_splits}") + + # compute size of splits + l = x.size(self.dim_split + 2) + split_size = l // self.num_splits + + # update padding length if necessary + padding = 3 + if padding % self.stride > 0: + padding = (padding // self.stride + 1) * self.stride + if self.print_info: + logger.info(f"Padding size: {padding}") + + # split tensor into a list of tensors + splits = self._split_tensor(x, split_size, padding) + + del x + _empty_cuda_cache() + + # convolution + outputs = [self.conv(split) for split in splits] + if self.print_info: + for j in range(len(outputs)): + logger.info(f"Output {j + 1}/{len(outputs)} size before: {outputs[j].size()}") + + # update size of splits and padding length for output + split_size_out = split_size + padding_s = padding + non_dim_split = self.dim_split + 1 if self.dim_split < 2 else 0 + if outputs[0].size(non_dim_split + 2) // splits[0].size(non_dim_split + 2) == 2: + split_size_out *= 2 + padding_s *= 2 + elif splits[0].size(non_dim_split + 2) // outputs[0].size(non_dim_split + 2) == 2: + split_size_out //= 2 + padding_s //= 2 + + # concatenate list of tensors + x = self._concatenate_tensors(outputs, split_size_out, padding_s) del outputs _empty_cuda_cache() From 27e7b81c2db3f3408144654bf9e9136a0a24672c Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Thu, 27 Jun 2024 16:12:41 +0000 Subject: [PATCH 33/37] update Signed-off-by: dongyang0122 --- tests/test_autoencoderkl_maisi.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py index 9b261e0b75..e392f68e92 100644 --- a/tests/test_autoencoderkl_maisi.py +++ b/tests/test_autoencoderkl_maisi.py @@ -63,6 +63,8 @@ "attention_levels": (False, False, True), "num_res_blocks": (1, 1, 1), "norm_num_groups": 4, + "with_encoder_nonlocal_attn": True, + "with_decoder_nonlocal_attn": True, "num_splits": 2, "print_info": False, }, From 90937d09f4a7d3a48411d71353046bf295e13901 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Thu, 27 Jun 2024 18:16:41 +0000 Subject: [PATCH 34/37] update Signed-off-by: dongyang0122 --- tests/test_autoencoderkl_maisi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py index e392f68e92..3d31107905 100644 --- a/tests/test_autoencoderkl_maisi.py +++ 
b/tests/test_autoencoderkl_maisi.py @@ -141,10 +141,10 @@ def test_model_num_channels_not_same_size_of_num_res_blocks(self): spatial_dims=3, in_channels=1, out_channels=1, - num_channels=(24, 24, 24), + num_channels=(23, 24, 25), attention_levels=(False, False, False), latent_channels=8, - num_res_blocks=(8, 8, 8), + num_res_blocks=(7, 8, 9), norm_num_groups=16, num_splits=2, print_info=False, From bf4c2ba5290e0ac4c653e181678607d055f01c05 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Fri, 28 Jun 2024 16:10:09 +0000 Subject: [PATCH 35/37] update Signed-off-by: dongyang0122 --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index c9adad0b9b..3ae389a80b 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -34,7 +34,6 @@ # Set up logging configuration -logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) From 27b3af60e55b1a756e84c06c49ebf51f27ae56a4 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Fri, 28 Jun 2024 17:27:15 +0000 Subject: [PATCH 36/37] update Signed-off-by: dongyang0122 --- .../maisi/networks/autoencoderkl_maisi.py | 110 +++++++++++++----- tests/test_autoencoderkl_maisi.py | 4 +- 2 files changed, 81 insertions(+), 33 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 3ae389a80b..c302324ef7 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -21,6 +21,7 @@ from monai.networks.blocks import Convolution from monai.utils import optional_import +from monai.utils.type_conversion import convert_to_tensor AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock") AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL") @@ -37,8 +38,8 @@ logger = logging.getLogger(__name__) -def _empty_cuda_cache(): - if torch.cuda.is_available(): +def _empty_cuda_cache(save_mem: bool) -> None: + if torch.cuda.is_available() and save_mem: torch.cuda.empty_cache() return @@ -54,6 +55,7 @@ class MaisiGroupNorm3D(nn.GroupNorm): affine: Whether to use learnable affine parameters, default to `True`. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. print_info: Whether to print information, default to `False`. + save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. 
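Removing the module-level logging.basicConfig call in PATCH 35 above is the conventional fix for library code: basicConfig installs a handler on the root logger and silently overrides the embedding application's logging policy. The module keeps only logging.getLogger(__name__) and emits through it when print_info is set; configuration is left to the consumer. A sketch of the consumer side, assuming the logger name follows from the module path in these diffs:

import logging

# Application code chooses destination, level, and format; the library above
# only calls logger.info(...) when print_info is enabled.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger("monai.apps.generation.maisi.networks.autoencoderkl_maisi").setLevel(logging.INFO)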
""" def __init__( @@ -64,10 +66,12 @@ def __init__( affine: bool = True, norm_float16: bool = False, print_info: bool = False, + save_mem: bool = True, ): super().__init__(num_groups, num_channels, eps, affine) self.norm_float16 = norm_float16 self.print_info = print_info + self.save_mem = save_mem def forward(self, input: torch.Tensor) -> torch.Tensor: if self.print_info: @@ -90,7 +94,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: inputs.append((array - mean) / std) del input - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) input = torch.cat(inputs, dim=1) if max(inputs[0].size()) < 500 else self._cat_inputs(inputs) @@ -107,12 +111,12 @@ def _cat_inputs(self, inputs): input_type = inputs[0].device.type input = inputs[0].clone().to("cpu", non_blocking=True) if input_type == "cuda" else inputs[0].clone() inputs[0] = 0 - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) for k in range(len(inputs) - 1): input = torch.cat((input, inputs[k + 1].cpu()), dim=1) inputs[k + 1] = 0 - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) gc.collect() if self.print_info: @@ -132,6 +136,7 @@ class MaisiConvolution(nn.Module): num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. print_info: Whether to print information. + save_mem: Whether to clean CUDA cache in order to save GPU memory. Additional arguments for the convolution operation. https://docs.monai.io/en/stable/networks.html#convolution """ @@ -144,6 +149,7 @@ def __init__( num_splits: int, dim_split: int, print_info: bool, + save_mem: bool, strides: Sequence[int] | int = 1, kernel_size: Sequence[int] | int = 3, adn_ordering: str = "NDA", @@ -184,6 +190,7 @@ def __init__( self.stride = strides[self.dim_split] if isinstance(strides, list) else strides self.num_splits = num_splits self.print_info = print_info + self.save_mem = save_mem def _split_tensor(self, x: torch.Tensor, split_size: int, padding: int) -> list[torch.Tensor]: overlaps = [0] + [padding] * (self.num_splits - 1) @@ -219,11 +226,11 @@ def _concatenate_tensors(self, outputs: list[torch.Tensor], split_size: int, pad else: x = outputs[0].clone().to("cpu", non_blocking=True) outputs[0] = torch.Tensor(0) - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) for k in range(len(outputs) - 1): x = torch.cat((x, outputs[k + 1].cpu()), dim=self.dim_split + 2) outputs[k + 1] = torch.Tensor(0) - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) gc.collect() if self.print_info: logger.info(f"MaisiConvolution concat progress: {k + 1}/{len(outputs) - 1}.") @@ -250,7 +257,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: splits = self._split_tensor(x, split_size, padding) del x - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) # convolution outputs = [self.conv(split) for split in splits] @@ -273,7 +280,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self._concatenate_tensors(outputs, split_size_out, padding_s) del outputs - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) return x @@ -289,6 +296,7 @@ class MaisiUpsample(nn.Module): num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. print_info: Whether to print information. + save_mem: Whether to clean CUDA cache in order to save GPU memory. 
""" def __init__( @@ -299,6 +307,7 @@ def __init__( num_splits: int, dim_split: int, print_info: bool, + save_mem: bool, ) -> None: super().__init__() self.conv = MaisiConvolution( @@ -313,25 +322,24 @@ def __init__( num_splits=num_splits, dim_split=dim_split, print_info=print_info, + save_mem=save_mem, ) self.use_convtranspose = use_convtranspose + self.save_mem = save_mem def forward(self, x: torch.Tensor) -> torch.Tensor: if self.use_convtranspose: x = self.conv(x) - - if isinstance(x, torch.Tensor): - return x - return torch.tensor(x) + x_tensor: torch.Tensor = convert_to_tensor(x) + return x_tensor x = F.interpolate(x, scale_factor=2.0, mode="trilinear") - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) x = self.conv(x) - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) - if isinstance(x, torch.Tensor): - return x - return torch.tensor(x) + out_tensor: torch.Tensor = convert_to_tensor(x) + return out_tensor class MaisiDownsample(nn.Module): @@ -344,9 +352,12 @@ class MaisiDownsample(nn.Module): num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. print_info: Whether to print information. + save_mem: Whether to clean CUDA cache in order to save GPU memory. """ - def __init__(self, spatial_dims: int, in_channels: int, num_splits: int, dim_split: int, print_info: bool) -> None: + def __init__( + self, spatial_dims: int, in_channels: int, num_splits: int, dim_split: int, print_info: bool, save_mem: bool + ) -> None: super().__init__() self.pad = (0, 1) * spatial_dims self.conv = MaisiConvolution( @@ -360,6 +371,7 @@ def __init__(self, spatial_dims: int, in_channels: int, num_splits: int, dim_spl num_splits=num_splits, dim_split=dim_split, print_info=print_info, + save_mem=save_mem, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -383,6 +395,7 @@ class MaisiResBlock(nn.Module): dim_split: Dimension of splitting for the input tensor. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. print_info: Whether to print information, default to `False`. + save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. 
""" def __init__( @@ -396,10 +409,12 @@ def __init__( dim_split: int, norm_float16: bool = False, print_info: bool = False, + save_mem: bool = True, ) -> None: super().__init__() self.in_channels = in_channels self.out_channels = in_channels if out_channels is None else out_channels + self.save_mem = save_mem self.norm1 = MaisiGroupNorm3D( num_groups=norm_num_groups, @@ -408,6 +423,7 @@ def __init__( affine=True, norm_float16=norm_float16, print_info=print_info, + save_mem=save_mem, ) self.conv1 = MaisiConvolution( spatial_dims=spatial_dims, @@ -420,6 +436,7 @@ def __init__( num_splits=num_splits, dim_split=dim_split, print_info=print_info, + save_mem=save_mem, ) self.norm2 = MaisiGroupNorm3D( num_groups=norm_num_groups, @@ -428,6 +445,7 @@ def __init__( affine=True, norm_float16=norm_float16, print_info=print_info, + save_mem=save_mem, ) self.conv2 = MaisiConvolution( spatial_dims=spatial_dims, @@ -440,6 +458,7 @@ def __init__( num_splits=num_splits, dim_split=dim_split, print_info=print_info, + save_mem=save_mem, ) self.nin_shortcut = ( @@ -454,6 +473,7 @@ def __init__( num_splits=num_splits, dim_split=dim_split, print_info=print_info, + save_mem=save_mem, ) if self.in_channels != self.out_channels else nn.Identity() @@ -461,29 +481,28 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: h = self.norm1(x) - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) h = F.silu(h) - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) h = self.conv1(h) - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) h = self.norm2(h) - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) h = F.silu(h) - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) h = self.conv2(h) - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) if self.in_channels != self.out_channels: x = self.nin_shortcut(x) - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) out = x + h - if isinstance(out, torch.Tensor): - return out - return torch.tensor(out) + out_tensor: torch.Tensor = convert_to_tensor(out) + return out_tensor class MaisiEncoder(nn.Module): @@ -505,6 +524,7 @@ class MaisiEncoder(nn.Module): dim_split: Dimension of splitting for the input tensor. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. print_info: Whether to print information, default to `False`. + save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. 
""" def __init__( @@ -521,11 +541,22 @@ def __init__( dim_split: int, norm_float16: bool = False, print_info: bool = False, + save_mem: bool = True, with_nonlocal_attn: bool = True, use_flash_attention: bool = False, ) -> None: super().__init__() + # Check if attention_levels and num_channels have the same size + if len(attention_levels) != len(num_channels): + raise ValueError("attention_levels and num_channels must have the same size") + + # Check if num_res_blocks and num_channels have the same size + if len(num_res_blocks) != len(num_channels): + raise ValueError("num_res_blocks and num_channels must have the same size") + + self.save_mem = save_mem + blocks: list[nn.Module] = [] blocks.append( @@ -540,6 +571,7 @@ def __init__( num_splits=num_splits, dim_split=dim_split, print_info=print_info, + save_mem=save_mem, ) ) @@ -561,6 +593,7 @@ def __init__( dim_split=dim_split, norm_float16=norm_float16, print_info=print_info, + save_mem=save_mem, ) ) input_channel = output_channel @@ -583,6 +616,7 @@ def __init__( num_splits=num_splits, dim_split=dim_split, print_info=print_info, + save_mem=save_mem, ) ) @@ -624,6 +658,7 @@ def __init__( affine=True, norm_float16=norm_float16, print_info=print_info, + save_mem=save_mem, ) ) blocks.append( @@ -638,6 +673,7 @@ def __init__( num_splits=num_splits, dim_split=dim_split, print_info=print_info, + save_mem=save_mem, ) ) @@ -646,7 +682,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: for block in self.blocks: x = block(x) - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) return x @@ -670,6 +706,7 @@ class MaisiDecoder(nn.Module): dim_split: Dimension of splitting for the input tensor. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. print_info: Whether to print information, default to `False`. + save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. """ def __init__( @@ -686,12 +723,14 @@ def __init__( dim_split: int, norm_float16: bool = False, print_info: bool = False, + save_mem: bool = True, with_nonlocal_attn: bool = True, use_flash_attention: bool = False, use_convtranspose: bool = False, ) -> None: super().__init__() self.print_info = print_info + self.save_mem = save_mem reversed_block_out_channels = list(reversed(num_channels)) @@ -709,6 +748,7 @@ def __init__( num_splits=num_splits, dim_split=dim_split, print_info=print_info, + save_mem=save_mem, ) ) @@ -761,6 +801,7 @@ def __init__( dim_split=dim_split, norm_float16=norm_float16, print_info=print_info, + save_mem=save_mem, ) ) block_in_ch = block_out_ch @@ -785,6 +826,7 @@ def __init__( num_splits=num_splits, dim_split=dim_split, print_info=print_info, + save_mem=save_mem, ) ) @@ -796,6 +838,7 @@ def __init__( affine=True, norm_float16=norm_float16, print_info=print_info, + save_mem=save_mem, ) ) blocks.append( @@ -810,6 +853,7 @@ def __init__( num_splits=num_splits, dim_split=dim_split, print_info=print_info, + save_mem=save_mem, ) ) @@ -818,7 +862,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: for block in self.blocks: x = block(x) - _empty_cuda_cache() + _empty_cuda_cache(self.save_mem) return x @@ -845,6 +889,7 @@ class AutoencoderKlMaisi(AutoencoderKLType): dim_split: Dimension of splitting for the input tensor. norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. print_info: Whether to print information, default to `False`. + save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. 
""" def __init__( @@ -867,6 +912,7 @@ def __init__( dim_split: int = 0, norm_float16: bool = False, print_info: bool = False, + save_mem: bool = True, ) -> None: super().__init__( spatial_dims, @@ -900,6 +946,7 @@ def __init__( dim_split=dim_split, norm_float16=norm_float16, print_info=print_info, + save_mem=save_mem, ) self.decoder = MaisiDecoder( @@ -918,4 +965,5 @@ def __init__( dim_split=dim_split, norm_float16=norm_float16, print_info=print_info, + save_mem=save_mem, ) diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py index 3d31107905..e88dc469c9 100644 --- a/tests/test_autoencoderkl_maisi.py +++ b/tests/test_autoencoderkl_maisi.py @@ -141,10 +141,10 @@ def test_model_num_channels_not_same_size_of_num_res_blocks(self): spatial_dims=3, in_channels=1, out_channels=1, - num_channels=(23, 24, 25), + num_channels=(24, 24), attention_levels=(False, False, False), latent_channels=8, - num_res_blocks=(7, 8, 9), + num_res_blocks=(8, 8, 8), norm_num_groups=16, num_splits=2, print_info=False, From 5ad7deefd9093ec58cb2d3f7f491b4c343becad8 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Fri, 28 Jun 2024 17:49:31 +0000 Subject: [PATCH 37/37] update Signed-off-by: dongyang0122 --- .../maisi/networks/autoencoderkl_maisi.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index c302324ef7..533da32fa0 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -136,7 +136,7 @@ class MaisiConvolution(nn.Module): num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. print_info: Whether to print information. - save_mem: Whether to clean CUDA cache in order to save GPU memory. + save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. Additional arguments for the convolution operation. https://docs.monai.io/en/stable/networks.html#convolution """ @@ -149,7 +149,7 @@ def __init__( num_splits: int, dim_split: int, print_info: bool, - save_mem: bool, + save_mem: bool = True, strides: Sequence[int] | int = 1, kernel_size: Sequence[int] | int = 3, adn_ordering: str = "NDA", @@ -296,7 +296,7 @@ class MaisiUpsample(nn.Module): num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. print_info: Whether to print information. - save_mem: Whether to clean CUDA cache in order to save GPU memory. + save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. """ def __init__( @@ -307,7 +307,7 @@ def __init__( num_splits: int, dim_split: int, print_info: bool, - save_mem: bool, + save_mem: bool = True, ) -> None: super().__init__() self.conv = MaisiConvolution( @@ -352,11 +352,17 @@ class MaisiDownsample(nn.Module): num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. print_info: Whether to print information. - save_mem: Whether to clean CUDA cache in order to save GPU memory. + save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. 
""" def __init__( - self, spatial_dims: int, in_channels: int, num_splits: int, dim_split: int, print_info: bool, save_mem: bool + self, + spatial_dims: int, + in_channels: int, + num_splits: int, + dim_split: int, + print_info: bool, + save_mem: bool = True, ) -> None: super().__init__() self.pad = (0, 1) * spatial_dims