diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py
index 276bf86f..f0187d5c 100644
--- a/generative/networks/nets/autoencoderkl.py
+++ b/generative/networks/nets/autoencoderkl.py
@@ -45,21 +45,38 @@ class Upsample(nn.Module):
     Args:
         spatial_dims: number of spatial dimensions (1D, 2D, 3D).
         in_channels: number of input channels to the layer.
+        use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
     """

-    def __init__(self, spatial_dims: int, in_channels: int) -> None:
+    def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) -> None:
         super().__init__()
-        self.conv = Convolution(
-            spatial_dims=spatial_dims,
-            in_channels=in_channels,
-            out_channels=in_channels,
-            strides=1,
-            kernel_size=3,
-            padding=1,
-            conv_only=True,
-        )
+        if use_convtranspose:
+            self.conv = Convolution(
+                spatial_dims=spatial_dims,
+                in_channels=in_channels,
+                out_channels=in_channels,
+                strides=2,
+                kernel_size=3,
+                padding=1,
+                conv_only=True,
+                is_transposed=True,
+            )
+        else:
+            self.conv = Convolution(
+                spatial_dims=spatial_dims,
+                in_channels=in_channels,
+                out_channels=in_channels,
+                strides=1,
+                kernel_size=3,
+                padding=1,
+                conv_only=True,
+            )
+        self.use_convtranspose = use_convtranspose

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.use_convtranspose:
+            return self.conv(x)
+
         # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
         # https://github.com/pytorch/pytorch/issues/86679
         dtype = x.dtype
@@ -450,6 +467,7 @@ class Decoder(nn.Module):
         attention_levels: indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
         use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+        use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
     """

     def __init__(
@@ -464,6 +482,7 @@ def __init__(
         attention_levels: Sequence[bool],
         with_nonlocal_attn: bool = True,
         use_flash_attention: bool = False,
+        use_convtranspose: bool = False,
     ) -> None:
         super().__init__()
         self.spatial_dims = spatial_dims
@@ -553,7 +572,9 @@ def __init__(
             )

             if not is_final_block:
-                blocks.append(Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch))
+                blocks.append(
+                    Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=use_convtranspose)
+                )

         blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
         blocks.append(
@@ -595,6 +616,8 @@ class AutoencoderKL(nn.Module):
         with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
         with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
         use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+        use_checkpointing: if True, use activation checkpointing to save memory.
+        use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
     """

     def __init__(
@@ -611,6 +634,8 @@ def __init__(
         with_encoder_nonlocal_attn: bool = True,
         with_decoder_nonlocal_attn: bool = True,
         use_flash_attention: bool = False,
+        use_checkpointing: bool = False,
+        use_convtranspose: bool = False,
     ) -> None:
         super().__init__()

@@ -658,6 +683,7 @@ def __init__(
             attention_levels=attention_levels,
             with_nonlocal_attn=with_decoder_nonlocal_attn,
             use_flash_attention=use_flash_attention,
+            use_convtranspose=use_convtranspose,
         )
         self.quant_conv_mu = Convolution(
             spatial_dims=spatial_dims,
@@ -687,6 +713,7 @@ def __init__(
             conv_only=True,
         )
         self.latent_channels = latent_channels
+        self.use_checkpointing = use_checkpointing

     def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         """
@@ -696,7 +723,10 @@ def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
             x: BxCx[SPATIAL DIMS] tensor

         """
-        h = self.encoder(x)
+        if self.use_checkpointing:
+            h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False)
+        else:
+            h = self.encoder(x)

         z_mu = self.quant_conv_mu(h)
         z_log_var = self.quant_conv_log_sigma(h)
@@ -747,7 +777,10 @@ def decode(self, z: torch.Tensor) -> torch.Tensor:
             decoded image tensor
         """
         z = self.post_quant_conv(z)
-        dec = self.decoder(z)
+        if self.use_checkpointing:
+            dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False)
+        else:
+            dec = self.decoder(z)
         return dec

     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" def __init__( @@ -321,6 +322,7 @@ def __init__( act: tuple | str | None = Act.RELU, output_act: tuple | str | None = None, ddp_sync: bool = True, + use_checkpointing: bool = False, ): super().__init__() @@ -330,6 +332,7 @@ def __init__( self.num_channels = num_channels self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim + self.use_checkpointing = use_checkpointing if isinstance(num_res_channels, int): num_res_channels = ensure_tuple_rep(num_res_channels, len(num_channels)) @@ -412,14 +415,20 @@ def __init__( ) def encode(self, images: torch.Tensor) -> torch.Tensor: - return self.encoder(images) + if self.use_checkpointing: + return torch.utils.checkpoint.checkpoint(self.encoder, images, use_reentrant=False) + else: + return self.encoder(images) def quantize(self, encodings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: x_loss, x = self.quantizer(encodings) return x, x_loss def decode(self, quantizations: torch.Tensor) -> torch.Tensor: - return self.decoder(quantizations) + if self.use_checkpointing: + return torch.utils.checkpoint.checkpoint(self.decoder, quantizations, use_reentrant=False) + else: + return self.decoder(quantizations) def index_quantize(self, images: torch.Tensor) -> torch.Tensor: return self.quantizer.quantize(self.encode(images=images)) @@ -434,7 +443,7 @@ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return reconstruction, quantization_losses def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: - z = self.encoder(x) + z = self.encode(x) e, _ = self.quantize(z) return e diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py index 3c6ca000..6a2e9820 100644 --- a/tests/test_autoencoderkl.py +++ b/tests/test_autoencoderkl.py @@ -143,6 +143,19 @@ def test_shape(self, input_param, input_shape, expected_shape, expected_latent_s self.assertEqual(result[1].shape, expected_latent_shape) self.assertEqual(result[2].shape, expected_latent_shape) + @parameterized.expand(CASES) + def test_shape_with_convtranspose_and_checkpointing( + self, input_param, input_shape, expected_shape, expected_latent_shape + ): + input_param = input_param.copy() + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + self.assertEqual(result[2].shape, expected_latent_shape) + # def test_script(self): # input_param, input_shape, _, _ = CASES[0] # net = AutoencoderKL(**input_param) @@ -195,6 +208,15 @@ def test_shape_reconstruction(self): result = net.reconstruct(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) + def test_shape_reconstruction_with_convtranspose_and_checkpointing(self): + input_param, input_shape, expected_shape, _ = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.reconstruct(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + def test_shape_encode(self): input_param, input_shape, _, expected_latent_shape = CASES[0] net = AutoencoderKL(**input_param).to(device) @@ -203,6 +225,16 @@ def test_shape_encode(self): self.assertEqual(result[0].shape, expected_latent_shape) self.assertEqual(result[1].shape, 
diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py
index 3c6ca000..6a2e9820 100644
--- a/tests/test_autoencoderkl.py
+++ b/tests/test_autoencoderkl.py
@@ -143,6 +143,19 @@ def test_shape(self, input_param, input_shape, expected_shape, expected_latent_s
         self.assertEqual(result[1].shape, expected_latent_shape)
         self.assertEqual(result[2].shape, expected_latent_shape)

+    @parameterized.expand(CASES)
+    def test_shape_with_convtranspose_and_checkpointing(
+        self, input_param, input_shape, expected_shape, expected_latent_shape
+    ):
+        input_param = input_param.copy()
+        input_param.update({"use_checkpointing": True, "use_convtranspose": True})
+        net = AutoencoderKL(**input_param).to(device)
+        with eval_mode(net):
+            result = net.forward(torch.randn(input_shape).to(device))
+        self.assertEqual(result[0].shape, expected_shape)
+        self.assertEqual(result[1].shape, expected_latent_shape)
+        self.assertEqual(result[2].shape, expected_latent_shape)
+
     # def test_script(self):
     #     input_param, input_shape, _, _ = CASES[0]
     #     net = AutoencoderKL(**input_param)
@@ -195,6 +208,15 @@ def test_shape_reconstruction(self):
             result = net.reconstruct(torch.randn(input_shape).to(device))
         self.assertEqual(result.shape, expected_shape)

+    def test_shape_reconstruction_with_convtranspose_and_checkpointing(self):
+        input_param, input_shape, expected_shape, _ = CASES[0]
+        input_param = input_param.copy()
+        input_param.update({"use_checkpointing": True, "use_convtranspose": True})
+        net = AutoencoderKL(**input_param).to(device)
+        with eval_mode(net):
+            result = net.reconstruct(torch.randn(input_shape).to(device))
+        self.assertEqual(result.shape, expected_shape)
+
     def test_shape_encode(self):
         input_param, input_shape, _, expected_latent_shape = CASES[0]
         net = AutoencoderKL(**input_param).to(device)
@@ -203,6 +225,16 @@ def test_shape_encode(self):
         self.assertEqual(result[0].shape, expected_latent_shape)
         self.assertEqual(result[1].shape, expected_latent_shape)

+    def test_shape_encode_with_convtranspose_and_checkpointing(self):
+        input_param, input_shape, _, expected_latent_shape = CASES[0]
+        input_param = input_param.copy()
+        input_param.update({"use_checkpointing": True, "use_convtranspose": True})
+        net = AutoencoderKL(**input_param).to(device)
+        with eval_mode(net):
+            result = net.encode(torch.randn(input_shape).to(device))
+        self.assertEqual(result[0].shape, expected_latent_shape)
+        self.assertEqual(result[1].shape, expected_latent_shape)
+
     def test_shape_sampling(self):
         input_param, _, _, expected_latent_shape = CASES[0]
         net = AutoencoderKL(**input_param).to(device)
@@ -212,6 +244,17 @@
             )
         self.assertEqual(result.shape, expected_latent_shape)

+    def test_shape_sampling_convtranspose_and_checkpointing(self):
+        input_param, _, _, expected_latent_shape = CASES[0]
+        input_param = input_param.copy()
+        input_param.update({"use_checkpointing": True, "use_convtranspose": True})
+        net = AutoencoderKL(**input_param).to(device)
+        with eval_mode(net):
+            result = net.sampling(
+                torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device)
+            )
+        self.assertEqual(result.shape, expected_latent_shape)
+
     def test_shape_decode(self):
         input_param, expected_input_shape, _, latent_shape = CASES[0]
         net = AutoencoderKL(**input_param).to(device)
@@ -219,6 +262,15 @@
             result = net.decode(torch.randn(latent_shape).to(device))
         self.assertEqual(result.shape, expected_input_shape)

+    def test_shape_decode_convtranspose_and_checkpointing(self):
+        input_param, expected_input_shape, _, latent_shape = CASES[0]
+        input_param = input_param.copy()
+        input_param.update({"use_checkpointing": True, "use_convtranspose": True})
+        net = AutoencoderKL(**input_param).to(device)
+        with eval_mode(net):
+            result = net.decode(torch.randn(latent_shape).to(device))
+        self.assertEqual(result.shape, expected_input_shape)
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py
index 9d9ccade..d2e17636 100644
--- a/tests/test_vqvae.py
+++ b/tests/test_vqvae.py
@@ -18,7 +18,6 @@
 from parameterized import parameterized

 from generative.networks.nets.vqvae import VQVAE
-from tests.utils import test_script_save

 TEST_CASES = [
     [
@@ -113,22 +112,36 @@ def test_shape(self, input_param, input_shape, expected_shape):

         self.assertEqual(result.shape, expected_shape)

-    def test_script(self):
-        net = VQVAE(
-            spatial_dims=2,
-            in_channels=1,
-            out_channels=1,
-            downsample_parameters=((2, 4, 1, 1),) * 2,
-            upsample_parameters=((2, 4, 1, 1, 0),) * 2,
-            num_res_layers=1,
-            num_channels=(8, 8),
-            num_res_channels=(8, 8),
-            num_embeddings=16,
-            embedding_dim=8,
-            ddp_sync=False,
-        )
-        test_data = torch.randn(1, 1, 16, 16)
-        test_script_save(net, test_data)
+    @parameterized.expand(TEST_CASES)
+    def test_shape_with_checkpoint(self, input_param, input_shape, expected_shape):
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        input_param = input_param.copy()
+        input_param.update({"use_checkpointing": True})
+
+        net = VQVAE(**input_param).to(device)
+
+        with eval_mode(net):
+            result, _ = net(torch.randn(input_shape).to(device))
+
+        self.assertEqual(result.shape, expected_shape)
+
+    # Removed this test case since TorchScript currently does not support activation checkpointing.
+    # def test_script(self):
+    #     net = VQVAE(
+    #         spatial_dims=2,
+    #         in_channels=1,
+    #         out_channels=1,
+    #         downsample_parameters=((2, 4, 1, 1),) * 2,
+    #         upsample_parameters=((2, 4, 1, 1, 0),) * 2,
+    #         num_res_layers=1,
+    #         num_channels=(8, 8),
+    #         num_res_channels=(8, 8),
+    #         num_embeddings=16,
+    #         embedding_dim=8,
+    #         ddp_sync=False,
+    #     )
+    #     test_data = torch.randn(1, 1, 16, 16)
+    #     test_script_save(net, test_data)

     def test_num_channels_not_same_size_of_num_res_channels(self):
         with self.assertRaises(ValueError):
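For reference, the checkpointing call pattern both networks adopt above, shown on a stand-alone toy module (the module and tensor sizes here are made up for illustration). Activations of the wrapped call are recomputed during the backward pass instead of being stored, so the memory saving only applies while gradients are being tracked; in eval/no-grad mode (as in the tests) it only verifies that shapes are unchanged.

import torch
import torch.nn as nn
import torch.utils.checkpoint

# Toy stand-in for the encoder/decoder submodules wrapped by the patch.
toy_encoder = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 8, 3, padding=1))
x = torch.randn(2, 1, 16, 16, requires_grad=True)

# Same call pattern as AutoencoderKL.encode / VQVAE.encode with use_checkpointing=True:
h = torch.utils.checkpoint.checkpoint(toy_encoder, x, use_reentrant=False)
h.sum().backward()  # intermediate activations inside toy_encoder are recomputed here rather than kept from forward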