diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index b9718e67f279..68a5341d513f 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -40,6 +40,13 @@ def forward(self, hidden_states, output_size=None): if self.use_conv_transpose: return self.conv(hidden_states) + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + # if `output_size` is passed we force the interpolation output # size and do not make use of `scale_factor=2` if output_size is None: @@ -47,6 +54,10 @@ def forward(self, hidden_states, output_size=None): else: hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: if self.name == "conv": diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 00e72de6551a..bd1aade0bc01 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -327,7 +327,9 @@ def __call__( image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) image, has_nsfw_concept = self.safety_checker( diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 773a7d4b2107..eedd74e88dfb 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -36,8 +36,9 @@ def forward(self, clip_input, images): pooled_output = self.vision_model(clip_input)[1] # pooled_output image_embeds = self.visual_projection(pooled_output) - special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy() - cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy() + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() result = [] batch_size = image_embeds.shape[0]