From 187f9fc827888341afda97f4e0d8101afc98107d Mon Sep 17 00:00:00 2001 From: dzy7e <228218809@qq.com> Date: Fri, 12 May 2023 09:58:33 +0800 Subject: [PATCH 1/4] gradient checkpointing bug fix --- src/diffusers/models/unet_2d_blocks.py | 132 +++++++++++++----- src/diffusers/models/vae.py | 33 +++-- .../versatile_diffusion/modeling_text_unet.py | 58 +++++--- 3 files changed, 161 insertions(+), 62 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 2f7b19b7328a..6dce45266bda 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -856,13 +856,23 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] + if torch.__version__>="1.11.0": + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + use_reentrant = False, + )[0] + else: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( @@ -947,7 +957,10 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if torch.__version__>="1.11.0": + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + else: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -1340,7 +1353,10 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if torch.__version__>="1.11.0": + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + else: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -1537,7 +1553,10 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if torch.__version__>="1.11.0": + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + else: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -1632,14 +1651,25 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - attention_mask, - cross_attention_kwargs, - ) + if torch.__version__>="1.11.0": + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + attention_mask, + cross_attention_kwargs, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + attention_mask, + cross_attention_kwargs, + ) else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( @@ -1848,13 +1878,23 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] + if torch.__version__>="1.11.0": + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + use_reentrant = False, + )[0] + else: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( @@ -1934,7 +1974,10 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if torch.__version__>="1.11.0": + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + else: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -2351,7 +2394,10 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if torch.__version__>="1.11.0": + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + else: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -2556,7 +2602,10 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if torch.__version__>="1.11.0": + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + else: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -2677,14 +2726,25 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - attention_mask, - cross_attention_kwargs, - )[0] + if torch.__version__>="1.11.0": + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + attention_mask, + cross_attention_kwargs, + use_reentrant=False, + )[0] + else: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + attention_mask, + cross_attention_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 400c3030af90..d597b04b76a0 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -118,10 +118,14 @@ def custom_forward(*inputs): # down for down_block in self.down_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) - - # middle - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + if torch.__version__ >= "1.11.0": + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample, use_reentrant=False) + # middle + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, use_reentrant=False) + else: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) + # middle + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) else: # down @@ -221,13 +225,22 @@ def custom_forward(*inputs): return custom_forward - # middle - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) - sample = sample.to(upscale_dtype) + if torch.__version__>="1.11.0": + # middle + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, use_reentrant=False) + sample = sample.to(upscale_dtype) - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample) + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, use_reentrant=False) + else: + # middle + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample) else: # middle sample = self.mid_block(sample) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index f0a210339c46..2b9de2b23c3d 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1077,7 +1077,10 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if torch.__version__>="1.11.0": + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + else: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -1198,13 +1201,23 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] + if torch.__version__>="1.11.0": + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + use_reentrant=False, + )[0] + else: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( @@ -1289,7 +1302,10 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if torch.__version__>="1.11.0": + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + else: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -1412,13 +1428,23 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] + if torch.__version__>="1.11.0": + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + use_reentrant=False, + )[0] + else: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( From ab6522a58c9b35dae0959dd952001c4573f5eb32 Mon Sep 17 00:00:00 2001 From: dzy7e <228218809@qq.com> Date: Tue, 16 May 2023 20:18:14 +0800 Subject: [PATCH 2/4] bug fix; changes for reviews --- src/diffusers/models/unet_2d_blocks.py | 21 ++++++++++--------- src/diffusers/models/vae.py | 19 +++++++++-------- .../versatile_diffusion/modeling_text_unet.py | 10 ++++----- 3 files changed, 26 insertions(+), 24 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index f35a08187b96..bd1105bbae04 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -18,6 +18,7 @@ import torch.nn.functional as F from torch import nn +from ..utils import is_torch_version from .attention import AdaGroupNorm from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from .dual_transformer_2d import DualTransformer2DModel @@ -866,7 +867,7 @@ def custom_forward(*inputs): return custom_forward - if torch.__version__>="1.11.0": + if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), @@ -967,7 +968,7 @@ def custom_forward(*inputs): return custom_forward - if torch.__version__>="1.11.0": + if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) else: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) @@ -1374,7 +1375,7 @@ def custom_forward(*inputs): return custom_forward - if torch.__version__>="1.11.0": + if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) else: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) @@ -1574,7 +1575,7 @@ def custom_forward(*inputs): return custom_forward - if torch.__version__>="1.11.0": + if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) else: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) @@ -1672,7 +1673,7 @@ def custom_forward(*inputs): return custom_forward - if torch.__version__>="1.11.0": + if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), @@ -1904,7 +1905,7 @@ def custom_forward(*inputs): return custom_forward - if torch.__version__>="1.11.0": + if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), @@ -2000,7 +2001,7 @@ def custom_forward(*inputs): return custom_forward - if torch.__version__>="1.11.0": + if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) else: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) @@ -2431,7 +2432,7 @@ def custom_forward(*inputs): return custom_forward - if torch.__version__>="1.11.0": + if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) else: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) @@ -2639,7 +2640,7 @@ def custom_forward(*inputs): return custom_forward - if torch.__version__>="1.11.0": + if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) else: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) @@ -2763,7 +2764,7 @@ def custom_forward(*inputs): return custom_forward - if torch.__version__>="1.11.0": + if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index d597b04b76a0..671a0ec5c144 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn -from ..utils import BaseOutput, randn_tensor +from ..utils import BaseOutput, randn_tensor, is_torch_version from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block @@ -117,15 +117,16 @@ def custom_forward(*inputs): return custom_forward # down - for down_block in self.down_blocks: - if torch.__version__ >= "1.11.0": + if is_torch_version(">=", "1.11.0"): + for down_block in self.down_blocks: sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample, use_reentrant=False) - # middle - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, use_reentrant=False) - else: + # middle + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, use_reentrant=False) + else: + for down_block in self.down_blocks: sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) - # middle - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + # middle + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) else: # down @@ -225,7 +226,7 @@ def custom_forward(*inputs): return custom_forward - if torch.__version__>="1.11.0": + if is_torch_version(">=", "1.11.0"): # middle sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, use_reentrant=False) sample = sample.to(upscale_dtype) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 2b9de2b23c3d..c327d0fef0a3 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -18,7 +18,7 @@ from ...models.embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps from ...models.transformer_2d import Transformer2DModel from ...models.unet_2d_condition import UNet2DConditionOutput -from ...utils import logging +from ...utils import logging, is_torch_version logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -1077,7 +1077,7 @@ def custom_forward(*inputs): return custom_forward - if torch.__version__>="1.11.0": + if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) else: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) @@ -1201,7 +1201,7 @@ def custom_forward(*inputs): return custom_forward - if torch.__version__>="1.11.0": + if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), @@ -1302,7 +1302,7 @@ def custom_forward(*inputs): return custom_forward - if torch.__version__>="1.11.0": + if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) else: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) @@ -1428,7 +1428,7 @@ def custom_forward(*inputs): return custom_forward - if torch.__version__>="1.11.0": + if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), From 94c8dcaadfbe043a20adf52d62ef85ac5189f1b3 Mon Sep 17 00:00:00 2001 From: dzy7e <228218809@qq.com> Date: Tue, 16 May 2023 20:38:40 +0800 Subject: [PATCH 3/4] reformat --- src/diffusers/models/unet_2d_blocks.py | 84 ++++++++++++++----- src/diffusers/models/vae.py | 16 +++- .../versatile_diffusion/modeling_text_unet.py | 32 +++++-- 3 files changed, 98 insertions(+), 34 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index bd1105bbae04..7b76dd7e37bd 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -868,16 +868,20 @@ def custom_forward(*inputs): return custom_forward if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, cross_attention_kwargs, - use_reentrant = False, + use_reentrant=False, )[0] else: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, @@ -969,9 +973,13 @@ def custom_forward(*inputs): return custom_forward if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -1376,9 +1384,13 @@ def custom_forward(*inputs): return custom_forward if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -1576,9 +1588,13 @@ def custom_forward(*inputs): return custom_forward if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -1674,7 +1690,9 @@ def custom_forward(*inputs): return custom_forward if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, @@ -1684,7 +1702,9 @@ def custom_forward(*inputs): use_reentrant=False, ) else: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, @@ -1906,16 +1926,20 @@ def custom_forward(*inputs): return custom_forward if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, cross_attention_kwargs, - use_reentrant = False, + use_reentrant=False, )[0] else: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, @@ -2002,9 +2026,13 @@ def custom_forward(*inputs): return custom_forward if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -2433,9 +2461,13 @@ def custom_forward(*inputs): return custom_forward if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -2641,9 +2673,13 @@ def custom_forward(*inputs): return custom_forward if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -2765,7 +2801,9 @@ def custom_forward(*inputs): return custom_forward if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, @@ -2775,7 +2813,9 @@ def custom_forward(*inputs): use_reentrant=False, )[0] else: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 671a0ec5c144..fb19dcd032ec 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -119,9 +119,13 @@ def custom_forward(*inputs): # down if is_torch_version(">=", "1.11.0"): for down_block in self.down_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample, use_reentrant=False) + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), sample, use_reentrant=False + ) # middle - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, use_reentrant=False) + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, use_reentrant=False + ) else: for down_block in self.down_blocks: sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) @@ -228,12 +232,16 @@ def custom_forward(*inputs): if is_torch_version(">=", "1.11.0"): # middle - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, use_reentrant=False) + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, use_reentrant=False + ) sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, use_reentrant=False) + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), sample, use_reentrant=False + ) else: # middle sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index c327d0fef0a3..baa7fcf8d57c 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1078,9 +1078,13 @@ def custom_forward(*inputs): return custom_forward if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -1202,7 +1206,9 @@ def custom_forward(*inputs): return custom_forward if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, @@ -1211,7 +1217,9 @@ def custom_forward(*inputs): use_reentrant=False, )[0] else: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, @@ -1303,9 +1311,13 @@ def custom_forward(*inputs): return custom_forward if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -1429,7 +1441,9 @@ def custom_forward(*inputs): return custom_forward if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, @@ -1438,7 +1452,9 @@ def custom_forward(*inputs): use_reentrant=False, )[0] else: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, From eb5fc40a65a8303868204190ed7149afe13693dc Mon Sep 17 00:00:00 2001 From: dzy7e <228218809@qq.com> Date: Tue, 16 May 2023 20:46:10 +0800 Subject: [PATCH 4/4] reformat --- src/diffusers/models/vae.py | 2 +- .../pipelines/versatile_diffusion/modeling_text_unet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index fb19dcd032ec..6f8514f28d33 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn -from ..utils import BaseOutput, randn_tensor, is_torch_version +from ..utils import BaseOutput, is_torch_version, randn_tensor from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index baa7fcf8d57c..7aaa0e49e1da 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -18,7 +18,7 @@ from ...models.embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps from ...models.transformer_2d import Transformer2DModel from ...models.unet_2d_condition import UNet2DConditionOutput -from ...utils import logging, is_torch_version +from ...utils import is_torch_version, logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name