From 52b04de04211984eb848f0705eb93977bda4cc99 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Thu, 27 Jul 2023 18:04:17 -0400 Subject: [PATCH 1/3] support convtranspose checkpointing --- generative/networks/nets/autoencoderkl.py | 146 +++++++++++++--------- 1 file changed, 89 insertions(+), 57 deletions(-) diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index 276bf86f..aede6dbb 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -45,21 +45,38 @@ class Upsample(nn.Module): Args: spatial_dims: number of spatial dimensions (1D, 2D, 3D). in_channels: number of input channels to the layer. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. """ - def __init__(self, spatial_dims: int, in_channels: int) -> None: + def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) -> None: super().__init__() - self.conv = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) + if use_convtranspose: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=1, + conv_only=True, + is_transposed=True + ) + else: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.use_convtranspose = use_convtranspose def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_convtranspose: + return self.conv(x) + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 # https://github.com/pytorch/pytorch/issues/86679 dtype = x.dtype @@ -120,7 +137,7 @@ class ResBlock(nn.Module): """ def __init__( - self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int + self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int ) -> None: super().__init__() self.in_channels = in_channels @@ -191,13 +208,13 @@ class AttentionBlock(nn.Module): """ def __init__( - self, - spatial_dims: int, - num_channels: int, - num_head_channels: int | None = None, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - use_flash_attention: bool = False, + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + use_flash_attention: bool = False, ) -> None: super().__init__() self.use_flash_attention = use_flash_attention @@ -232,7 +249,7 @@ def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: return x def _memory_efficient_attention_xformers( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor ) -> torch.Tensor: query = query.contiguous() key = key.contiguous() @@ -313,17 +330,17 @@ class Encoder(nn.Module): """ def __init__( - self, - spatial_dims: int, - in_channels: int, - num_channels: Sequence[int], - out_channels: int, - num_res_blocks: Sequence[int], - norm_num_groups: int, - norm_eps: float, - attention_levels: Sequence[bool], - with_nonlocal_attn: bool = True, - use_flash_attention: bool = False, + self, + spatial_dims: int, + in_channels: int, + num_channels: Sequence[int], + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + 
norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -450,20 +467,22 @@ class Decoder(nn.Module): attention_levels: indicate which level from num_channels contain an attention block. with_nonlocal_attn: if True use non-local attention block. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. """ def __init__( - self, - spatial_dims: int, - num_channels: Sequence[int], - in_channels: int, - out_channels: int, - num_res_blocks: Sequence[int], - norm_num_groups: int, - norm_eps: float, - attention_levels: Sequence[bool], - with_nonlocal_attn: bool = True, - use_flash_attention: bool = False, + self, + spatial_dims: int, + num_channels: Sequence[int], + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_convtranspose: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -553,7 +572,8 @@ def __init__( ) if not is_final_block: - blocks.append(Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch)) + blocks.append( + Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=use_convtranspose)) blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) blocks.append( @@ -595,22 +615,26 @@ class AutoencoderKL(nn.Module): with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + use_checkpointing if True, use activation checkpointing to save memory. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. 
""" def __init__( - self, - spatial_dims: int, - in_channels: int = 1, - out_channels: int = 1, - num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - num_channels: Sequence[int] = (32, 64, 64, 64), - attention_levels: Sequence[bool] = (False, False, True, True), - latent_channels: int = 3, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - with_encoder_nonlocal_attn: bool = True, - with_decoder_nonlocal_attn: bool = True, - use_flash_attention: bool = False, + self, + spatial_dims: int, + in_channels: int = 1, + out_channels: int = 1, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + latent_channels: int = 3, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + with_encoder_nonlocal_attn: bool = True, + with_decoder_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_checkpointing: bool = False, + use_convtranspose: bool = False, ) -> None: super().__init__() @@ -658,6 +682,7 @@ def __init__( attention_levels=attention_levels, with_nonlocal_attn=with_decoder_nonlocal_attn, use_flash_attention=use_flash_attention, + use_convtranspose=use_convtranspose, ) self.quant_conv_mu = Convolution( spatial_dims=spatial_dims, @@ -687,6 +712,7 @@ def __init__( conv_only=True, ) self.latent_channels = latent_channels + self.use_checkpointing = use_checkpointing def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -696,7 +722,10 @@ def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: x: BxCx[SPATIAL DIMS] tensor """ - h = self.encoder(x) + if self.use_checkpointing: + h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False) + else: + h = self.encoder(x) z_mu = self.quant_conv_mu(h) z_log_var = self.quant_conv_log_sigma(h) @@ -747,7 +776,10 @@ def decode(self, z: torch.Tensor) -> torch.Tensor: decoded image tensor """ z = self.post_quant_conv(z) - dec = self.decoder(z) + if self.use_checkpointing: + dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False) + else: + dec = self.decoder(z) return dec def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: From 668767f746b10fea77eb27bd004cea4bdf2b15f2 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Tue, 1 Aug 2023 18:14:58 -0400 Subject: [PATCH 2/3] add unittests --- generative/networks/nets/autoencoderkl.py | 99 ++++++++++++----------- tests/test_autoencoderkl.py | 47 +++++++++++ 2 files changed, 97 insertions(+), 49 deletions(-) diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index aede6dbb..9d89289e 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -59,7 +59,7 @@ def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) kernel_size=3, padding=1, conv_only=True, - is_transposed=True + is_transposed=True, ) else: self.conv = Convolution( @@ -137,7 +137,7 @@ class ResBlock(nn.Module): """ def __init__( - self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int + self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int ) -> None: super().__init__() self.in_channels = in_channels @@ -208,13 +208,13 @@ class AttentionBlock(nn.Module): """ def __init__( - self, - spatial_dims: int, - num_channels: int, - num_head_channels: int | None = None, - norm_num_groups: int = 32, - norm_eps: float 
= 1e-6, - use_flash_attention: bool = False, + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + use_flash_attention: bool = False, ) -> None: super().__init__() self.use_flash_attention = use_flash_attention @@ -249,7 +249,7 @@ def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: return x def _memory_efficient_attention_xformers( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor ) -> torch.Tensor: query = query.contiguous() key = key.contiguous() @@ -330,17 +330,17 @@ class Encoder(nn.Module): """ def __init__( - self, - spatial_dims: int, - in_channels: int, - num_channels: Sequence[int], - out_channels: int, - num_res_blocks: Sequence[int], - norm_num_groups: int, - norm_eps: float, - attention_levels: Sequence[bool], - with_nonlocal_attn: bool = True, - use_flash_attention: bool = False, + self, + spatial_dims: int, + in_channels: int, + num_channels: Sequence[int], + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -471,18 +471,18 @@ class Decoder(nn.Module): """ def __init__( - self, - spatial_dims: int, - num_channels: Sequence[int], - in_channels: int, - out_channels: int, - num_res_blocks: Sequence[int], - norm_num_groups: int, - norm_eps: float, - attention_levels: Sequence[bool], - with_nonlocal_attn: bool = True, - use_flash_attention: bool = False, - use_convtranspose: bool = False, + self, + spatial_dims: int, + num_channels: Sequence[int], + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_convtranspose: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -573,7 +573,8 @@ def __init__( if not is_final_block: blocks.append( - Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=use_convtranspose)) + Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=use_convtranspose) + ) blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) blocks.append( @@ -620,21 +621,21 @@ class AutoencoderKL(nn.Module): """ def __init__( - self, - spatial_dims: int, - in_channels: int = 1, - out_channels: int = 1, - num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - num_channels: Sequence[int] = (32, 64, 64, 64), - attention_levels: Sequence[bool] = (False, False, True, True), - latent_channels: int = 3, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - with_encoder_nonlocal_attn: bool = True, - with_decoder_nonlocal_attn: bool = True, - use_flash_attention: bool = False, - use_checkpointing: bool = False, - use_convtranspose: bool = False, + self, + spatial_dims: int, + in_channels: int = 1, + out_channels: int = 1, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + latent_channels: int = 3, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + with_encoder_nonlocal_attn: bool = True, + with_decoder_nonlocal_attn: bool = True, + use_flash_attention: 
bool = False, + use_checkpointing: bool = False, + use_convtranspose: bool = False, ) -> None: super().__init__() diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py index 3c6ca000..726ddda7 100644 --- a/tests/test_autoencoderkl.py +++ b/tests/test_autoencoderkl.py @@ -143,6 +143,18 @@ def test_shape(self, input_param, input_shape, expected_shape, expected_latent_s self.assertEqual(result[1].shape, expected_latent_shape) self.assertEqual(result[2].shape, expected_latent_shape) + @parameterized.expand(CASES) + def test_shape_with_convtranspose_and_checkpointing( + self, input_param, input_shape, expected_shape, expected_latent_shape + ): + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + self.assertEqual(result[2].shape, expected_latent_shape) + # def test_script(self): # input_param, input_shape, _, _ = CASES[0] # net = AutoencoderKL(**input_param) @@ -195,6 +207,14 @@ def test_shape_reconstruction(self): result = net.reconstruct(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) + def test_shape_reconstruction_with_convtranspose_and_checkpointing(self): + input_param, input_shape, expected_shape, _ = CASES[0] + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.reconstruct(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + def test_shape_encode(self): input_param, input_shape, _, expected_latent_shape = CASES[0] net = AutoencoderKL(**input_param).to(device) @@ -203,6 +223,15 @@ def test_shape_encode(self): self.assertEqual(result[0].shape, expected_latent_shape) self.assertEqual(result[1].shape, expected_latent_shape) + def test_shape_encode_with_convtranspose_and_checkpointing(self): + input_param, input_shape, _, expected_latent_shape = CASES[0] + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + def test_shape_sampling(self): input_param, _, _, expected_latent_shape = CASES[0] net = AutoencoderKL(**input_param).to(device) @@ -212,6 +241,16 @@ def test_shape_sampling(self): ) self.assertEqual(result.shape, expected_latent_shape) + def test_shape_sampling_convtranspose_and_checkpointing(self): + input_param, _, _, expected_latent_shape = CASES[0] + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + def test_shape_decode(self): input_param, expected_input_shape, _, latent_shape = CASES[0] net = AutoencoderKL(**input_param).to(device) @@ -219,6 +258,14 @@ def test_shape_decode(self): result = net.decode(torch.randn(latent_shape).to(device)) self.assertEqual(result.shape, expected_input_shape) + def test_shape_decode_convtranspose_and_checkpointing(self): + input_param, 
expected_input_shape, _, latent_shape = CASES[0] + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + if __name__ == "__main__": unittest.main() From d89e70306151511aa882825bb15f3e18a7f17756 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Thu, 10 Aug 2023 12:28:11 -0400 Subject: [PATCH 3/3] support checkpoint in vqvae and fix unittests --- generative/networks/nets/autoencoderkl.py | 2 +- generative/networks/nets/vqvae.py | 15 ++++++-- tests/test_autoencoderkl.py | 5 +++ tests/test_vqvae.py | 47 +++++++++++++++-------- 4 files changed, 48 insertions(+), 21 deletions(-) diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index 9d89289e..f0187d5c 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -616,7 +616,7 @@ class AutoencoderKL(nn.Module): with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - use_checkpointing if True, use activation checkpointing to save memory. + use_checkpointing: if True, use activation checkpointing to save memory. use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. """ diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index 39335587..74173067 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -297,6 +297,7 @@ class VQVAE(nn.Module): dropout: dropout ratio. output_act: activation type and arguments for the output. ddp_sync: whether to synchronize the codebook across processes. + use_checkpointing if True, use activation checkpointing to save memory. 
""" def __init__( @@ -321,6 +322,7 @@ def __init__( act: tuple | str | None = Act.RELU, output_act: tuple | str | None = None, ddp_sync: bool = True, + use_checkpointing: bool = False, ): super().__init__() @@ -330,6 +332,7 @@ def __init__( self.num_channels = num_channels self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim + self.use_checkpointing = use_checkpointing if isinstance(num_res_channels, int): num_res_channels = ensure_tuple_rep(num_res_channels, len(num_channels)) @@ -412,14 +415,20 @@ def __init__( ) def encode(self, images: torch.Tensor) -> torch.Tensor: - return self.encoder(images) + if self.use_checkpointing: + return torch.utils.checkpoint.checkpoint(self.encoder, images, use_reentrant=False) + else: + return self.encoder(images) def quantize(self, encodings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: x_loss, x = self.quantizer(encodings) return x, x_loss def decode(self, quantizations: torch.Tensor) -> torch.Tensor: - return self.decoder(quantizations) + if self.use_checkpointing: + return torch.utils.checkpoint.checkpoint(self.decoder, quantizations, use_reentrant=False) + else: + return self.decoder(quantizations) def index_quantize(self, images: torch.Tensor) -> torch.Tensor: return self.quantizer.quantize(self.encode(images=images)) @@ -434,7 +443,7 @@ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return reconstruction, quantization_losses def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: - z = self.encoder(x) + z = self.encode(x) e, _ = self.quantize(z) return e diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py index 726ddda7..6a2e9820 100644 --- a/tests/test_autoencoderkl.py +++ b/tests/test_autoencoderkl.py @@ -147,6 +147,7 @@ def test_shape(self, input_param, input_shape, expected_shape, expected_latent_s def test_shape_with_convtranspose_and_checkpointing( self, input_param, input_shape, expected_shape, expected_latent_shape ): + input_param = input_param.copy() input_param.update({"use_checkpointing": True, "use_convtranspose": True}) net = AutoencoderKL(**input_param).to(device) with eval_mode(net): @@ -209,6 +210,7 @@ def test_shape_reconstruction(self): def test_shape_reconstruction_with_convtranspose_and_checkpointing(self): input_param, input_shape, expected_shape, _ = CASES[0] + input_param = input_param.copy() input_param.update({"use_checkpointing": True, "use_convtranspose": True}) net = AutoencoderKL(**input_param).to(device) with eval_mode(net): @@ -225,6 +227,7 @@ def test_shape_encode(self): def test_shape_encode_with_convtranspose_and_checkpointing(self): input_param, input_shape, _, expected_latent_shape = CASES[0] + input_param = input_param.copy() input_param.update({"use_checkpointing": True, "use_convtranspose": True}) net = AutoencoderKL(**input_param).to(device) with eval_mode(net): @@ -243,6 +246,7 @@ def test_shape_sampling(self): def test_shape_sampling_convtranspose_and_checkpointing(self): input_param, _, _, expected_latent_shape = CASES[0] + input_param = input_param.copy() input_param.update({"use_checkpointing": True, "use_convtranspose": True}) net = AutoencoderKL(**input_param).to(device) with eval_mode(net): @@ -260,6 +264,7 @@ def test_shape_decode(self): def test_shape_decode_convtranspose_and_checkpointing(self): input_param, expected_input_shape, _, latent_shape = CASES[0] + input_param = input_param.copy() input_param.update({"use_checkpointing": True, "use_convtranspose": True}) net = AutoencoderKL(**input_param).to(device) 
with eval_mode(net): diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py index 9d9ccade..d2e17636 100644 --- a/tests/test_vqvae.py +++ b/tests/test_vqvae.py @@ -18,7 +18,6 @@ from parameterized import parameterized from generative.networks.nets.vqvae import VQVAE -from tests.utils import test_script_save TEST_CASES = [ [ @@ -113,22 +112,36 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) - def test_script(self): - net = VQVAE( - spatial_dims=2, - in_channels=1, - out_channels=1, - downsample_parameters=((2, 4, 1, 1),) * 2, - upsample_parameters=((2, 4, 1, 1, 0),) * 2, - num_res_layers=1, - num_channels=(8, 8), - num_res_channels=(8, 8), - num_embeddings=16, - embedding_dim=8, - ddp_sync=False, - ) - test_data = torch.randn(1, 1, 16, 16) - test_script_save(net, test_data) + @parameterized.expand(TEST_CASES) + def test_shape_with_checkpoint(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + input_param = input_param.copy() + input_param.update({"use_checkpointing": True}) + + net = VQVAE(**input_param).to(device) + + with eval_mode(net): + result, _ = net(torch.randn(input_shape).to(device)) + + self.assertEqual(result.shape, expected_shape) + + # Removed this test case since TorchScript currently does not support activation checkpoint. + # def test_script(self): + # net = VQVAE( + # spatial_dims=2, + # in_channels=1, + # out_channels=1, + # downsample_parameters=((2, 4, 1, 1),) * 2, + # upsample_parameters=((2, 4, 1, 1, 0),) * 2, + # num_res_layers=1, + # num_channels=(8, 8), + # num_res_channels=(8, 8), + # num_embeddings=16, + # embedding_dim=8, + # ddp_sync=False, + # ) + # test_data = torch.randn(1, 1, 16, 16) + # test_script_save(net, test_data) def test_num_channels_not_same_size_of_num_res_channels(self): with self.assertRaises(ValueError):
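
Note (not part of the patches above): a minimal sketch of how the flags introduced by this series might be exercised once applied. The VQVAE configuration reuses the small network from the removed `test_script` case; the AutoencoderKL configuration is an illustrative, made-up small setup, not one of the test CASES.

    import torch

    from generative.networks.nets.autoencoderkl import AutoencoderKL
    from generative.networks.nets.vqvae import VQVAE

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # AutoencoderKL with transposed-convolution upsampling and activation checkpointing.
    # (Illustrative small configuration; norm_num_groups must divide every entry of num_channels.)
    autoencoder = AutoencoderKL(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        num_res_blocks=(1, 1, 1),
        num_channels=(8, 8, 8),
        attention_levels=(False, False, True),
        latent_channels=3,
        norm_num_groups=8,
        use_checkpointing=True,   # encoder/decoder wrapped in torch.utils.checkpoint.checkpoint(..., use_reentrant=False)
        use_convtranspose=True,   # Upsample blocks use a strided ConvTranspose instead of interpolate + conv
    ).to(device)

    # VQVAE only gains the checkpointing flag; the rest of the constructor is unchanged.
    vqvae = VQVAE(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        downsample_parameters=((2, 4, 1, 1),) * 2,
        upsample_parameters=((2, 4, 1, 1, 0),) * 2,
        num_res_layers=1,
        num_channels=(8, 8),
        num_res_channels=(8, 8),
        num_embeddings=16,
        embedding_dim=8,
        use_checkpointing=True,
    ).to(device)

    # Activation checkpointing trades recomputation for memory, so the saving only
    # materialises when gradients are tracked; an eval/no_grad forward behaves as before.
    x = torch.randn(1, 1, 16, 16, device=device)
    reconstruction, z_mu, z_sigma = autoencoder(x)
    reconstruction.sum().backward()

    vq_reconstruction, vq_loss = vqvae(torch.randn(1, 1, 16, 16, device=device))

As the disabled `test_script` cases indicate, TorchScript export is not expected to work when `use_checkpointing=True`, since torch.utils.checkpoint is not scriptable.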