@@ -412,6 +412,7 @@ def main():

if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
vae.enable_gradient_checkpointing()

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
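The training-script hunk above extends the existing `--gradient_checkpointing` branch so the VAE, like the UNet, recomputes activations during the backward pass instead of storing them. A minimal sketch of the effect, assuming a diffusers install that includes this change and using the library's default (small) `AutoencoderKL` config purely for illustration:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL()                # default tiny config, just for the sketch
vae.enable_gradient_checkpointing()  # what the training script now calls for the VAE
vae.train()                          # checkpointing only applies in training mode

x = torch.randn(1, 3, 32, 32)
out = vae(x).sample                  # encode + decode with checkpointed blocks
out.mean().backward()                # block activations are recomputed here
```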
6 changes: 6 additions & 0 deletions src/diffusers/models/autoencoder_kl.py
@@ -65,6 +65,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
"""

_supports_gradient_checkpointing = True

@register_to_config
def __init__(
self,
@@ -121,6 +123,10 @@ def __init__(
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, Decoder)):
module.gradient_checkpointing = value

def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
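The `autoencoder_kl.py` change is the usual diffusers opt-in: `_supports_gradient_checkpointing = True` lets `ModelMixin.enable_gradient_checkpointing()` accept the call, and `_set_gradient_checkpointing` routes the flag to the `Encoder` and `Decoder` submodules, which are the only modules that act on it. A hedged sketch of how the flag propagates (default config, illustration only):

```python
from diffusers import AutoencoderKL

vae = AutoencoderKL()                # default config is enough to show the flag
vae.enable_gradient_checkpointing()  # applies _set_gradient_checkpointing(value=True) to submodules

# Only the Encoder and Decoder carry the gradient_checkpointing attribute.
print(vae.is_gradient_checkpointing)       # True
print(vae.encoder.gradient_checkpointing)  # True
print(vae.decoder.gradient_checkpointing)  # True
```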
4 changes: 1 addition & 3 deletions src/diffusers/models/cross_attention.py
@@ -24,9 +24,7 @@
SlicedAttnProcessor,
XFormersAttnProcessor,
)
from .attention_processor import ( # noqa: F401
AttnProcessor as AttnProcessorRename,
)
from .attention_processor import AttnProcessor as AttnProcessorRename # noqa: F401


deprecate(
71 changes: 59 additions & 12 deletions src/diffusers/models/vae.py
@@ -50,7 +50,13 @@ def __init__(
super().__init__()
self.layers_per_block = layers_per_block

self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
self.conv_in = torch.nn.Conv2d(
in_channels,
block_out_channels[0],
kernel_size=3,
stride=1,
padding=1,
)

self.mid_block = None
self.down_blocks = nn.ModuleList([])
@@ -96,16 +102,34 @@ def __init__(
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

self.gradient_checkpointing = False

def forward(self, x):
sample = x
sample = self.conv_in(sample)

# down
for down_block in self.down_blocks:
sample = down_block(sample)
if self.training and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)

return custom_forward

# down
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)

# middle
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)

else:
# down
for down_block in self.down_blocks:
sample = down_block(sample)

# middle
sample = self.mid_block(sample)
# middle
sample = self.mid_block(sample)

# post-process
sample = self.conv_norm_out(sample)
@@ -129,7 +153,13 @@ def __init__(
super().__init__()
self.layers_per_block = layers_per_block

self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[-1],
kernel_size=3,
stride=1,
padding=1,
)

self.mid_block = None
self.up_blocks = nn.ModuleList([])
@@ -176,16 +206,33 @@ def __init__(
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

self.gradient_checkpointing = False

def forward(self, z):
sample = z
sample = self.conv_in(sample)

# middle
sample = self.mid_block(sample)
if self.training and self.gradient_checkpointing:

# up
for up_block in self.up_blocks:
sample = up_block(sample)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)

return custom_forward

# middle
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)

# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
else:
# middle
sample = self.mid_block(sample)

# up
for up_block in self.up_blocks:
sample = up_block(sample)

# post-process
sample = self.conv_norm_out(sample)
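Both new forward paths in `vae.py` use the same wrap-and-checkpoint idiom: `create_custom_forward` turns a block into a plain function so `torch.utils.checkpoint.checkpoint` can re-invoke it with the saved inputs during backward, trading activation memory for recompute. A standalone toy illustration of that idiom (not diffusers code; the module and shapes are made up):

```python
import torch
import torch.nn as nn

# Toy stand-in for a down/up block.
block = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.SiLU(), nn.Conv2d(8, 3, 3, padding=1))

def create_custom_forward(module):
    # Same wrapper shape as in the diff: checkpoint() wants a callable taking *inputs.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward

x = torch.randn(2, 3, 16, 16, requires_grad=True)
# Activations inside `block` are not kept after the forward pass; they are
# recomputed when backward() reaches this checkpoint.
y = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 3, 16, 16])
```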
41 changes: 41 additions & 0 deletions tests/models/test_models_vae.py
@@ -68,6 +68,47 @@ def test_forward_signature(self):
def test_training(self):
pass

@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
def test_gradient_checkpointing(self):
# first run a forward/backward pass without gradient checkpointing as a reference
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)

assert not model.is_gradient_checkpointing and model.training

out = model(**inputs_dict).sample
# run a backward pass; for simplicity we backprop the mean difference to
# random targets rather than computing a real training loss
model.zero_grad()

labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()

# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# copy the weights so both models start from identical parameters
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()

assert model_2.is_gradient_checkpointing and model_2.training

out_2 = model_2(**inputs_dict).sample
# run the same backward pass with gradient checkpointing enabled,
# reusing the random targets from the reference run
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()

# compare the losses and the parameter gradients
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))

def test_from_pretrained_hub(self):
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
self.assertIsNotNone(model)