From 2c86a944208a5c8acfac22a2fe967f94ee11ad1b Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 10 Oct 2022 12:38:00 +0200 Subject: [PATCH 1/3] support bf16 for stable diffusion --- src/diffusers/models/resnet.py | 9 +++++++++ .../stable_diffusion/pipeline_stable_diffusion.py | 8 +++++++- .../pipelines/stable_diffusion/safety_checker.py | 12 ++++++++++-- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index b9718e67f279..5b1f88e5416e 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -40,6 +40,11 @@ def forward(self, hidden_states, output_size=None): if self.use_conv_transpose: return self.conv(hidden_states) + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + # if `output_size` is passed we force the interpolation output # size and do not make use of `scale_factor=2` if output_size is None: @@ -47,6 +52,10 @@ def forward(self, hidden_states, output_size=None): else: hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: if self.name == "conv": diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 00e72de6551a..335d442437f9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -327,7 +327,13 @@ def __call__( image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() + image = 
image.cpu().permute(0, 2, 3, 1) + + # cast to float32 to as numpy doesn't support bfloat16 + if image.dtype == torch.bfloat16: + image = image.to(torch.float32).numpy() + else: + image = image.numpy() safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) image, has_nsfw_concept = self.safety_checker( diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 773a7d4b2107..2dbd0bbba6eb 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -36,8 +36,16 @@ def forward(self, clip_input, images): pooled_output = self.vision_model(clip_input)[1] # pooled_output image_embeds = self.visual_projection(pooled_output) - special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy() - cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy() + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu() + + # cast to float32 to as numpy does not support bfloat16 + if image_embeds == torch.bfloat16: + special_cos_dist = special_cos_dist.float().numpy() + cos_dist = cos_dist.float().numpy() + else: + special_cos_dist = special_cos_dist.numpy() + cos_dist = cos_dist.numpy() result = [] batch_size = image_embeds.shape[0] From 90385bbd59f9011f2b12b71ec436df8a9ea50a98 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 10 Oct 2022 12:49:09 +0200 Subject: [PATCH 2/3] fix typo --- src/diffusers/pipelines/stable_diffusion/safety_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 2dbd0bbba6eb..68e4fef6bb6e 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ 
b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -40,7 +40,7 @@ def forward(self, clip_input, images): cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu() # cast to float32 to as numpy does not support bfloat16 - if image_embeds == torch.bfloat16: + if image_embeds.dtype == torch.bfloat16: special_cos_dist = special_cos_dist.float().numpy() cos_dist = cos_dist.float().numpy() else: From 880f9f9ae1bf27a4b16e34f7194f515f9d99575e Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 11 Oct 2022 11:31:59 +0200 Subject: [PATCH 3/3] address review comments --- src/diffusers/models/resnet.py | 2 ++ .../stable_diffusion/pipeline_stable_diffusion.py | 8 ++------ .../pipelines/stable_diffusion/safety_checker.py | 13 +++---------- 3 files changed, 7 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 5b1f88e5416e..68a5341d513f 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -41,6 +41,8 @@ def forward(self, hidden_states, output_size=None): return self.conv(hidden_states) # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 dtype = hidden_states.dtype if dtype == torch.bfloat16: hidden_states = hidden_states.to(torch.float32) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 335d442437f9..bd1aade0bc01 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -327,13 +327,9 @@ def __call__( image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1) - # cast to float32 to as numpy doesn't support bfloat16 - if image.dtype == 
torch.bfloat16: - image = image.to(torch.float32).numpy() - else: - image = image.numpy() + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) image, has_nsfw_concept = self.safety_checker( diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 68e4fef6bb6e..eedd74e88dfb 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -36,16 +36,9 @@ def forward(self, clip_input, images): pooled_output = self.vision_model(clip_input)[1] # pooled_output image_embeds = self.visual_projection(pooled_output) - special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu() - cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu() - - # cast to float32 to as numpy does not support bfloat16 - if image_embeds.dtype == torch.bfloat16: - special_cos_dist = special_cos_dist.float().numpy() - cos_dist = cos_dist.float().numpy() - else: - special_cos_dist = special_cos_dist.numpy() - cos_dist = cos_dist.numpy() + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() result = [] batch_size = image_embeds.shape[0]