diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
index f91aa2f8d29b..637b35b3f695 100644
--- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
+++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
@@ -412,6 +412,7 @@ def main():
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        vae.enable_gradient_checkpointing()
 
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py
index 3ee0c56796fe..9c0161065e4c 100644
--- a/src/diffusers/models/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoder_kl.py
@@ -65,6 +65,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
     Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
     """
 
+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -121,6 +123,10 @@ def __init__(
         self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (Encoder, Decoder)):
+            module.gradient_checkpointing = value
+
     def enable_tiling(self, use_tiling: bool = True):
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 91a0dbfa1238..4fdb2acaabed 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -24,9 +24,7 @@
     SlicedAttnProcessor,
     XFormersAttnProcessor,
 )
-from .attention_processor import (  # noqa: F401
-    AttnProcessor as AttnProcessorRename,
-)
+from .attention_processor import AttnProcessor as AttnProcessorRename  # noqa: F401
 
 
 deprecate(
diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py
index c5142a8f15b7..b4484823ac3d 100644
--- a/src/diffusers/models/vae.py
+++ b/src/diffusers/models/vae.py
@@ -50,7 +50,13 @@ def __init__(
         super().__init__()
         self.layers_per_block = layers_per_block
 
-        self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
+        self.conv_in = torch.nn.Conv2d(
+            in_channels,
+            block_out_channels[0],
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
 
         self.mid_block = None
         self.down_blocks = nn.ModuleList([])
@@ -96,16 +102,34 @@ def __init__(
         conv_out_channels = 2 * out_channels if double_z else out_channels
         self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
 
+        self.gradient_checkpointing = False
+
     def forward(self, x):
         sample = x
         sample = self.conv_in(sample)
 
-        # down
-        for down_block in self.down_blocks:
-            sample = down_block(sample)
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            # down
+            for down_block in self.down_blocks:
+                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
+
+            # middle
+            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+
+        else:
+            # down
+            for down_block in self.down_blocks:
+                sample = down_block(sample)
 
-        # middle
-        sample = self.mid_block(sample)
+            # middle
+            sample = self.mid_block(sample)
 
         # post-process
         sample = self.conv_norm_out(sample)
@@ -129,7 +153,13 @@ def __init__(
         super().__init__()
         self.layers_per_block = layers_per_block
 
-        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
+        self.conv_in = nn.Conv2d(
+            in_channels,
+            block_out_channels[-1],
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
 
         self.mid_block = None
         self.up_blocks = nn.ModuleList([])
@@ -176,16 +206,33 @@ def __init__(
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
 
+        self.gradient_checkpointing = False
+
     def forward(self, z):
         sample = z
         sample = self.conv_in(sample)
 
-        # middle
-        sample = self.mid_block(sample)
+        if self.training and self.gradient_checkpointing:
 
-        # up
-        for up_block in self.up_blocks:
-            sample = up_block(sample)
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            # middle
+            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+
+            # up
+            for up_block in self.up_blocks:
+                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
+        else:
+            # middle
+            sample = self.mid_block(sample)
+
+            # up
+            for up_block in self.up_blocks:
+                sample = up_block(sample)
 
         # post-process
         sample = self.conv_norm_out(sample)
diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py
index 5d0aa194c1df..3eb7ce861592 100644
--- a/tests/models/test_models_vae.py
+++ b/tests/models/test_models_vae.py
@@ -68,6 +68,47 @@ def test_forward_signature(self):
     def test_training(self):
         pass
 
+    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
+    def test_gradient_checkpointing(self):
+        # build a reference model with gradient checkpointing disabled
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        assert not model.is_gradient_checkpointing and model.training
+
+        out = model(**inputs_dict).sample
+        # run a backward pass; for simplicity we backprop through a synthetic
+        # loss against random targets rather than a real training objective
+        model.zero_grad()
+
+        labels = torch.randn_like(out)
+        loss = (out - labels).mean()
+        loss.backward()
+
+        # re-instantiate the model, now enabling gradient checkpointing
+        model_2 = self.model_class(**init_dict)
+        # copy the weights so both models start from identical parameters
+        model_2.load_state_dict(model.state_dict())
+        model_2.to(torch_device)
+        model_2.enable_gradient_checkpointing()
+
+        assert model_2.is_gradient_checkpointing and model_2.training
+
+        out_2 = model_2(**inputs_dict).sample
+        # run the same backward pass; again backprop through the synthetic
+        # loss against the same random targets
+        model_2.zero_grad()
+        loss_2 = (out_2 - labels).mean()
+        loss_2.backward()
+
+        # compare the losses and the parameter gradients
+        self.assertTrue((loss - loss_2).abs() < 1e-5)
+        named_params = dict(model.named_parameters())
+        named_params_2 = dict(model_2.named_parameters())
+        for name, param in named_params.items():
+            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
+
     def test_from_pretrained_hub(self):
         model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
         self.assertIsNotNone(model)
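
Usage sketch (not part of the patch): the snippet below shows how the new flag would be exercised end to end. The pretrained checkpoint name, the dummy batch, and the reconstruction loss are illustrative assumptions; only `enable_gradient_checkpointing()` / `is_gradient_checkpointing` (from `ModelMixin`) and the `Encoder`/`Decoder` behavior come from the changes above.

```python
import torch
from diffusers import AutoencoderKL

# Illustrative checkpoint; any repo with a "vae" subfolder works the same way.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae.train()  # checkpointing is only active when self.training is True
vae.enable_gradient_checkpointing()
assert vae.is_gradient_checkpointing

# Dummy batch and a synthetic reconstruction loss, just to drive backward().
images = torch.randn(2, 3, 256, 256)
posterior = vae.encode(images).latent_dist
reconstruction = vae.decode(posterior.sample()).sample
loss = (reconstruction - images).pow(2).mean()

# Down/mid/up block activations are recomputed during this backward pass
# instead of being stored, trading extra compute for lower peak memory.
loss.backward()
```

The `create_custom_forward` closure mirrors the pattern already used for the UNet blocks in this codebase: it gives `torch.utils.checkpoint.checkpoint` a plain callable over tensor inputs, keeping the checkpointed call sites uniform across modules.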