59 changes: 46 additions & 13 deletions generative/networks/nets/autoencoderkl.py
@@ -45,21 +45,38 @@ class Upsample(nn.Module):
Args:
spatial_dims: number of spatial dimensions (1D, 2D, 3D).
in_channels: number of input channels to the layer.
use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
"""

def __init__(self, spatial_dims: int, in_channels: int) -> None:
def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) -> None:
super().__init__()
self.conv = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=in_channels,
strides=1,
kernel_size=3,
padding=1,
conv_only=True,
)
if use_convtranspose:
self.conv = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=in_channels,
strides=2,
kernel_size=3,
padding=1,
conv_only=True,
is_transposed=True,
)
else:
self.conv = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=in_channels,
strides=1,
kernel_size=3,
padding=1,
conv_only=True,
)
self.use_convtranspose = use_convtranspose

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_convtranspose:
return self.conv(x)

# Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
# https://github.com/pytorch/pytorch/issues/86679
dtype = x.dtype
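Editorial note (not part of the diff): the change above switches between two 2x upsampling strategies, the existing path (nearest-neighbour interpolation followed by a stride-1 convolution) and the new path (a stride-2 transposed convolution). The sketch below illustrates both paths with plain torch.nn layers; it is an illustration only, not MONAI's Convolution factory, and the kernel/padding/output_padding choices here are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UpsampleSketch(nn.Module):
    """2x upsampling either via a transposed convolution or via interpolate + conv."""

    def __init__(self, in_channels: int, use_convtranspose: bool) -> None:
        super().__init__()
        self.use_convtranspose = use_convtranspose
        if use_convtranspose:
            # Learnable upsampling: a stride-2 transposed convolution doubles the spatial size.
            self.conv = nn.ConvTranspose2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1
            )
        else:
            # Fixed nearest-neighbour upsampling followed by a stride-1 convolution.
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_convtranspose:
            return self.conv(x)
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(x)


x = torch.randn(1, 8, 16, 16)
print(UpsampleSketch(8, use_convtranspose=False)(x).shape)  # torch.Size([1, 8, 32, 32])
print(UpsampleSketch(8, use_convtranspose=True)(x).shape)   # torch.Size([1, 8, 32, 32])
```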
@@ -450,6 +467,7 @@ class Decoder(nn.Module):
attention_levels: indicate which level from num_channels contain an attention block.
with_nonlocal_attn: if True use non-local attention block.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
"""

def __init__(
@@ -464,6 +482,7 @@ def __init__(
attention_levels: Sequence[bool],
with_nonlocal_attn: bool = True,
use_flash_attention: bool = False,
use_convtranspose: bool = False,
) -> None:
super().__init__()
self.spatial_dims = spatial_dims
@@ -553,7 +572,9 @@ def __init__(
)

if not is_final_block:
blocks.append(Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch))
blocks.append(
Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=use_convtranspose)
)

blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
blocks.append(
@@ -595,6 +616,8 @@ class AutoencoderKL(nn.Module):
with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
use_checkpointing: if True, use activation checkpointing to save memory.
use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
"""

def __init__(
@@ -611,6 +634,8 @@ def __init__(
with_encoder_nonlocal_attn: bool = True,
with_decoder_nonlocal_attn: bool = True,
use_flash_attention: bool = False,
use_checkpointing: bool = False,
use_convtranspose: bool = False,
) -> None:
super().__init__()

@@ -658,6 +683,7 @@ def __init__(
attention_levels=attention_levels,
with_nonlocal_attn=with_decoder_nonlocal_attn,
use_flash_attention=use_flash_attention,
use_convtranspose=use_convtranspose,
)
self.quant_conv_mu = Convolution(
spatial_dims=spatial_dims,
@@ -687,6 +713,7 @@ def __init__(
conv_only=True,
)
self.latent_channels = latent_channels
self.use_checkpointing = use_checkpointing

def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -696,7 +723,10 @@ def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
x: BxCx[SPATIAL DIMS] tensor

"""
h = self.encoder(x)
if self.use_checkpointing:
h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False)
else:
h = self.encoder(x)

z_mu = self.quant_conv_mu(h)
z_log_var = self.quant_conv_log_sigma(h)
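Editorial note (not part of the diff): the new branch above wraps the encoder call in torch.utils.checkpoint.checkpoint with use_reentrant=False, so intermediate activations inside the encoder are discarded after the forward pass and recomputed during backward, trading compute for peak memory. A minimal generic illustration (not MONAI-specific):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

encoder = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
)
x = torch.randn(2, 1, 16, 16)

# Same output as encoder(x), but activations are recomputed during backward.
h = checkpoint(encoder, x, use_reentrant=False)
h.sum().backward()
```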
@@ -747,7 +777,10 @@ def decode(self, z: torch.Tensor) -> torch.Tensor:
decoded image tensor
"""
z = self.post_quant_conv(z)
dec = self.decoder(z)
if self.use_checkpointing:
dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False)
else:
dec = self.decoder(z)
return dec

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
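Editorial note (not part of the diff): taken together, the autoencoderkl.py changes add two opt-in constructor flags. A hypothetical usage sketch follows; the two new flags come from this diff, while the remaining constructor arguments and the output names are assumptions modelled on the test cases and may not match the real defaults.

```python
import torch
from generative.networks.nets.autoencoderkl import AutoencoderKL

# Hypothetical configuration: every argument except the two new flags is an assumption.
net = AutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    num_channels=(8, 8, 8),
    attention_levels=(False, False, False),
    latent_channels=8,
    norm_num_groups=8,
    use_checkpointing=True,   # recompute encoder/decoder activations during backward
    use_convtranspose=True,   # upsample in the decoder with transposed convolutions
)

x = torch.randn(1, 1, 32, 32)
# forward returns three tensors; the latent names below are assumptions based on encode().
reconstruction, z_mu, z_sigma = net(x)
```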
15 changes: 12 additions & 3 deletions generative/networks/nets/vqvae.py
@@ -297,6 +297,7 @@ class VQVAE(nn.Module):
dropout: dropout ratio.
output_act: activation type and arguments for the output.
ddp_sync: whether to synchronize the codebook across processes.
use_checkpointing: if True, use activation checkpointing to save memory.
"""

def __init__(
@@ -321,6 +322,7 @@ def __init__(
act: tuple | str | None = Act.RELU,
output_act: tuple | str | None = None,
ddp_sync: bool = True,
use_checkpointing: bool = False,
):
super().__init__()

@@ -330,6 +332,7 @@ def __init__(
self.num_channels = num_channels
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.use_checkpointing = use_checkpointing

if isinstance(num_res_channels, int):
num_res_channels = ensure_tuple_rep(num_res_channels, len(num_channels))
@@ -412,14 +415,20 @@ def __init__(
)

def encode(self, images: torch.Tensor) -> torch.Tensor:
return self.encoder(images)
if self.use_checkpointing:
return torch.utils.checkpoint.checkpoint(self.encoder, images, use_reentrant=False)
else:
return self.encoder(images)

def quantize(self, encodings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
x_loss, x = self.quantizer(encodings)
return x, x_loss

def decode(self, quantizations: torch.Tensor) -> torch.Tensor:
return self.decoder(quantizations)
if self.use_checkpointing:
return torch.utils.checkpoint.checkpoint(self.decoder, quantizations, use_reentrant=False)
else:
return self.decoder(quantizations)

def index_quantize(self, images: torch.Tensor) -> torch.Tensor:
return self.quantizer.quantize(self.encode(images=images))
Expand All @@ -434,7 +443,7 @@ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return reconstruction, quantization_losses

def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:
z = self.encoder(x)
z = self.encode(x)
e, _ = self.quantize(z)
return e

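Editorial note (not part of the diff): for vqvae.py the only new option is use_checkpointing. The sketch below reuses the configuration from the TorchScript test removed further down in this diff, with the new flag enabled; treat it as an illustrative example rather than a verified snippet.

```python
import torch
from generative.networks.nets.vqvae import VQVAE

net = VQVAE(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    downsample_parameters=((2, 4, 1, 1),) * 2,
    upsample_parameters=((2, 4, 1, 1, 0),) * 2,
    num_res_layers=1,
    num_channels=(8, 8),
    num_res_channels=(8, 8),
    num_embeddings=16,
    embedding_dim=8,
    ddp_sync=False,
    use_checkpointing=True,  # wrap the encoder/decoder calls in torch.utils.checkpoint
)

reconstruction, quantization_loss = net(torch.randn(1, 1, 16, 16))
print(reconstruction.shape)  # expected: torch.Size([1, 1, 16, 16])
```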
52 changes: 52 additions & 0 deletions tests/test_autoencoderkl.py
@@ -143,6 +143,19 @@ def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape):
self.assertEqual(result[1].shape, expected_latent_shape)
self.assertEqual(result[2].shape, expected_latent_shape)

@parameterized.expand(CASES)
def test_shape_with_convtranspose_and_checkpointing(
self, input_param, input_shape, expected_shape, expected_latent_shape
):
input_param = input_param.copy()
input_param.update({"use_checkpointing": True, "use_convtranspose": True})
net = AutoencoderKL(**input_param).to(device)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result[0].shape, expected_shape)
self.assertEqual(result[1].shape, expected_latent_shape)
self.assertEqual(result[2].shape, expected_latent_shape)

# def test_script(self):
# input_param, input_shape, _, _ = CASES[0]
# net = AutoencoderKL(**input_param)
@@ -195,6 +208,15 @@ def test_shape_reconstruction(self):
result = net.reconstruct(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

def test_shape_reconstruction_with_convtranspose_and_checkpointing(self):
input_param, input_shape, expected_shape, _ = CASES[0]
input_param = input_param.copy()
input_param.update({"use_checkpointing": True, "use_convtranspose": True})
net = AutoencoderKL(**input_param).to(device)
with eval_mode(net):
result = net.reconstruct(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

def test_shape_encode(self):
input_param, input_shape, _, expected_latent_shape = CASES[0]
net = AutoencoderKL(**input_param).to(device)
@@ -203,6 +225,16 @@ def test_shape_encode(self):
self.assertEqual(result[0].shape, expected_latent_shape)
self.assertEqual(result[1].shape, expected_latent_shape)

def test_shape_encode_with_convtranspose_and_checkpointing(self):
input_param, input_shape, _, expected_latent_shape = CASES[0]
input_param = input_param.copy()
input_param.update({"use_checkpointing": True, "use_convtranspose": True})
net = AutoencoderKL(**input_param).to(device)
with eval_mode(net):
result = net.encode(torch.randn(input_shape).to(device))
self.assertEqual(result[0].shape, expected_latent_shape)
self.assertEqual(result[1].shape, expected_latent_shape)

def test_shape_sampling(self):
input_param, _, _, expected_latent_shape = CASES[0]
net = AutoencoderKL(**input_param).to(device)
@@ -212,13 +244,33 @@ def test_shape_sampling(self):
)
self.assertEqual(result.shape, expected_latent_shape)

def test_shape_sampling_convtranspose_and_checkpointing(self):
input_param, _, _, expected_latent_shape = CASES[0]
input_param = input_param.copy()
input_param.update({"use_checkpointing": True, "use_convtranspose": True})
net = AutoencoderKL(**input_param).to(device)
with eval_mode(net):
result = net.sampling(
torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device)
)
self.assertEqual(result.shape, expected_latent_shape)

def test_shape_decode(self):
input_param, expected_input_shape, _, latent_shape = CASES[0]
net = AutoencoderKL(**input_param).to(device)
with eval_mode(net):
result = net.decode(torch.randn(latent_shape).to(device))
self.assertEqual(result.shape, expected_input_shape)

def test_shape_decode_convtranspose_and_checkpointing(self):
input_param, expected_input_shape, _, latent_shape = CASES[0]
input_param = input_param.copy()
input_param.update({"use_checkpointing": True, "use_convtranspose": True})
net = AutoencoderKL(**input_param).to(device)
with eval_mode(net):
result = net.decode(torch.randn(latent_shape).to(device))
self.assertEqual(result.shape, expected_input_shape)


if __name__ == "__main__":
unittest.main()
47 changes: 30 additions & 17 deletions tests/test_vqvae.py
@@ -18,7 +18,6 @@
from parameterized import parameterized

from generative.networks.nets.vqvae import VQVAE
from tests.utils import test_script_save

TEST_CASES = [
[
@@ -113,22 +112,36 @@ def test_shape(self, input_param, input_shape, expected_shape):

self.assertEqual(result.shape, expected_shape)

def test_script(self):
net = VQVAE(
spatial_dims=2,
in_channels=1,
out_channels=1,
downsample_parameters=((2, 4, 1, 1),) * 2,
upsample_parameters=((2, 4, 1, 1, 0),) * 2,
num_res_layers=1,
num_channels=(8, 8),
num_res_channels=(8, 8),
num_embeddings=16,
embedding_dim=8,
ddp_sync=False,
)
test_data = torch.randn(1, 1, 16, 16)
test_script_save(net, test_data)
@parameterized.expand(TEST_CASES)
def test_shape_with_checkpoint(self, input_param, input_shape, expected_shape):
device = "cuda" if torch.cuda.is_available() else "cpu"
input_param = input_param.copy()
input_param.update({"use_checkpointing": True})

net = VQVAE(**input_param).to(device)

with eval_mode(net):
result, _ = net(torch.randn(input_shape).to(device))

self.assertEqual(result.shape, expected_shape)

# Removed this test case since TorchScript currently does not support activation checkpointing.
# def test_script(self):
# net = VQVAE(
# spatial_dims=2,
# in_channels=1,
# out_channels=1,
# downsample_parameters=((2, 4, 1, 1),) * 2,
# upsample_parameters=((2, 4, 1, 1, 0),) * 2,
# num_res_layers=1,
# num_channels=(8, 8),
# num_res_channels=(8, 8),
# num_embeddings=16,
# embedding_dim=8,
# ddp_sync=False,
# )
# test_data = torch.randn(1, 1, 16, 16)
# test_script_save(net, test_data)

def test_num_channels_not_same_size_of_num_res_channels(self):
with self.assertRaises(ValueError):