diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index 73ea36ea..733fcdb4 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -61,7 +61,18 @@ def __init__( ) def forward(self, x: torch.Tensor) -> torch.Tensor: + # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # If the input was bfloat16, cast the result back to bfloat16 + if dtype == torch.bfloat16: + x = x.to(dtype) + x = self.conv(x) return x diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 271f85e4..b5ec38d8 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -576,7 +576,19 @@ def __init__( def forward(self, x: torch.Tensor, emb: Optional[torch.Tensor] = None) -> torch.Tensor: del emb assert x.shape[1] == self.num_channels + + # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # If the input was bfloat16, cast the result back to bfloat16 + if dtype == torch.bfloat16: + x = x.to(dtype) + if self.use_conv: x = self.conv(x) return x @@ -1783,6 +1795,11 @@ def forward( """ # 1. time t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # The timestep embedding does not contain any weights and will always return f32 tensors, + # but time_embed might actually be running in fp16, so we need to cast to the input dtype here. + # There might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) emb = self.time_embed(t_emb) # 2. 
class @@ -1790,6 +1807,7 @@ def forward( if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) emb = emb + class_emb # 3. initial convolution