diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
index f27f73ec60..a52274b24a 100644
--- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
+++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
@@ -13,25 +13,17 @@
 
 import gc
 import logging
-from typing import TYPE_CHECKING, Sequence, cast
+from typing import Sequence
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from monai.networks.blocks import Convolution
-from monai.utils import optional_import
+from monai.networks.blocks.spatialattention import SpatialAttentionBlock
+from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
 from monai.utils.type_conversion import convert_to_tensor
 
-AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock")
-AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL")
-ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock")
-
-if TYPE_CHECKING:
-    from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType
-else:
-    AutoencoderKLType = cast(type, AutoencoderKL)
-
 # Set up logging configuration
 logger = logging.getLogger(__name__)
 
@@ -518,11 +510,13 @@ class MaisiEncoder(nn.Module):
         in_channels: Number of input channels.
         num_channels: Sequence of block output channels.
         out_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
-        num_res_blocks: Number of residual blocks (see ResBlock) per level.
+        num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
         norm_num_groups: Number of groups for the group norm layers.
         norm_eps: Epsilon for the normalization.
         attention_levels: Indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: If True, use non-local attention block.
+        include_fc: whether to include the final linear layer in the attention block. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         num_splits: Number of splits for the input tensor.
         dim_split: Dimension of splitting for the input tensor.
@@ -547,6 +541,8 @@ def __init__(
         print_info: bool = False,
         save_mem: bool = True,
         with_nonlocal_attn: bool = True,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
@@ -603,11 +599,13 @@ def __init__(
                 input_channel = output_channel
                 if attention_levels[i]:
                     blocks.append(
-                        AttentionBlock(
+                        SpatialAttentionBlock(
                             spatial_dims=spatial_dims,
                             num_channels=input_channel,
                             norm_num_groups=norm_num_groups,
                             norm_eps=norm_eps,
+                            include_fc=include_fc,
+                            use_combined_linear=use_combined_linear,
                             use_flash_attention=use_flash_attention,
                         )
                     )
@@ -626,7 +624,7 @@ def __init__(
 
         if with_nonlocal_attn:
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
@@ -636,16 +634,18 @@ def __init__(
             )
 
             blocks.append(
-                AttentionBlock(
+                SpatialAttentionBlock(
                     spatial_dims=spatial_dims,
                     num_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
                     use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
@@ -699,11 +699,13 @@ class MaisiDecoder(nn.Module):
         num_channels: Sequence of block output channels.
         in_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
         out_channels: Number of output channels.
-        num_res_blocks: Number of residual blocks (see ResBlock) per level.
+        num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
         norm_num_groups: Number of groups for the group norm layers.
         norm_eps: Epsilon for the normalization.
         attention_levels: Indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: If True, use non-local attention block.
+        include_fc: whether to include the final linear layer in the attention block. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
         num_splits: Number of splits for the input tensor.
@@ -729,6 +731,8 @@ def __init__(
         print_info: bool = False,
         save_mem: bool = True,
         with_nonlocal_attn: bool = True,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         use_convtranspose: bool = False,
     ) -> None:
@@ -758,7 +762,7 @@ def __init__(
 
         if with_nonlocal_attn:
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
@@ -767,16 +771,18 @@ def __init__(
                 )
             )
             blocks.append(
-                AttentionBlock(
+                SpatialAttentionBlock(
                     spatial_dims=spatial_dims,
                     num_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
                     use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
@@ -812,11 +818,13 @@ def __init__(
 
             if reversed_attention_levels[i]:
                 blocks.append(
-                    AttentionBlock(
+                    SpatialAttentionBlock(
                         spatial_dims=spatial_dims,
                         num_channels=block_in_ch,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
+                        include_fc=include_fc,
+                        use_combined_linear=use_combined_linear,
                         use_flash_attention=use_flash_attention,
                     )
                 )
@@ -870,7 +878,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-class AutoencoderKlMaisi(AutoencoderKLType):
+class AutoencoderKlMaisi(AutoencoderKL):
     """
     AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
 
@@ -886,6 +894,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
         norm_eps: Epsilon for the normalization.
         with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder.
         with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder.
+        include_fc: whether to include the final linear layer. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         use_checkpointing: If True, use activation checkpointing.
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
@@ -909,6 +919,8 @@ def __init__(
         norm_eps: float = 1e-6,
         with_encoder_nonlocal_attn: bool = False,
         with_decoder_nonlocal_attn: bool = False,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         use_checkpointing: bool = False,
         use_convtranspose: bool = False,
@@ -930,12 +942,14 @@ def __init__(
             norm_eps,
             with_encoder_nonlocal_attn,
             with_decoder_nonlocal_attn,
-            use_flash_attention,
             use_checkpointing,
             use_convtranspose,
+            include_fc,
+            use_combined_linear,
+            use_flash_attention,
         )
 
-        self.encoder = MaisiEncoder(
+        self.encoder: nn.Module = MaisiEncoder(
             spatial_dims=spatial_dims,
             in_channels=in_channels,
             num_channels=num_channels,
@@ -945,6 +959,8 @@ def __init__(
             norm_eps=norm_eps,
             attention_levels=attention_levels,
             with_nonlocal_attn=with_encoder_nonlocal_attn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
             num_splits=num_splits,
             dim_split=dim_split,
@@ -953,7 +969,7 @@ def __init__(
             save_mem=save_mem,
         )
 
-        self.decoder = MaisiDecoder(
+        self.decoder: nn.Module = MaisiDecoder(
             spatial_dims=spatial_dims,
             num_channels=num_channels,
             in_channels=latent_channels,
@@ -963,6 +979,8 @@ def __init__(
             norm_eps=norm_eps,
             attention_levels=attention_levels,
             with_nonlocal_attn=with_decoder_nonlocal_attn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
             use_convtranspose=use_convtranspose,
             num_splits=num_splits,
diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py
index 836027796f..af191e748b 100644
--- a/monai/networks/nets/autoencoderkl.py
+++ b/monai/networks/nets/autoencoderkl.py
@@ -532,7 +532,7 @@ def __init__(
                 "`num_channels`."
             )
 
-        self.encoder = Encoder(
+        self.encoder: nn.Module = Encoder(
             spatial_dims=spatial_dims,
             in_channels=in_channels,
             channels=channels,
@@ -546,7 +546,7 @@ def __init__(
             use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
         )
-        self.decoder = Decoder(
+        self.decoder: nn.Module = Decoder(
             spatial_dims=spatial_dims,
             channels=channels,
             in_channels=latent_channels,
diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py
index 392a3d7db2..0e9f427fb6 100644
--- a/tests/test_autoencoderkl_maisi.py
+++ b/tests/test_autoencoderkl_maisi.py
@@ -16,16 +16,13 @@
 import torch
 from parameterized import parameterized
 
+from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi
 from monai.networks import eval_mode
 from monai.utils import optional_import
 from tests.utils import SkipIfBeforePyTorchVersion
 
 tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
 _, has_einops = optional_import("einops")
-_, has_generative = optional_import("generative")
-
-if has_generative:
-    from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
@@ -79,7 +76,6 @@
     CASES = CASES_NO_ATTENTION
 
 
-@unittest.skipUnless(has_generative, "monai-generative required")
 class TestAutoencoderKlMaisi(unittest.TestCase):
 
     @parameterized.expand(CASES)
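
Example usage (not part of the patch): a minimal smoke-test sketch showing that, after this change, AutoencoderKlMaisi can be imported and built directly from monai.apps with the new include_fc and use_combined_linear flags, without monai-generative installed. The keyword arguments mirror the constructor shown in this diff; the channel counts, split settings, and tensor shape are illustrative values only, not the configurations used in MONAI's test cases or the MAISI bundle.

# Illustrative only: small 3D configuration for a quick CPU check.
import torch

from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi

model = AutoencoderKlMaisi(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    latent_channels=4,
    num_channels=(4, 4, 4),
    num_res_blocks=(1, 1, 1),
    norm_num_groups=4,
    attention_levels=(False, False, False),
    include_fc=False,           # new flag surfaced by this diff, defaults to False
    use_combined_linear=False,  # new flag surfaced by this diff, defaults to False
    use_flash_attention=False,
    num_splits=2,               # tensor-splitting settings kept small for a toy volume
    dim_split=0,
)

# Forward pass follows the AutoencoderKL API: (reconstruction, z_mu, z_sigma).
with torch.no_grad():
    reconstruction, z_mu, z_sigma = model(torch.randn(1, 1, 16, 16, 16))
print(reconstruction.shape, z_mu.shape, z_sigma.shape)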