From cfd8e46d78df9ec1242720fd48e83046207cff3f Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 24 Dec 2022 12:24:54 +0000 Subject: [PATCH 1/3] Fix F.interpolate usage with bfloat16 Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/autoencoderkl.py | 11 +++++++++++ generative/networks/nets/diffusion_model_unet.py | 12 ++++++++++++ 2 files changed, 23 insertions(+) diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index f1a75636..8894ed74 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -45,7 +45,18 @@ def __init__( ) def forward(self, x: torch.Tensor) -> torch.Tensor: + # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + x = x.to(dtype) + x = self.conv(x) return x diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 7aabfbba..7752221b 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -515,7 +515,19 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: assert x.shape[1] == self.num_channels + + # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + x = x.to(dtype) + if self.use_conv: x = self.conv(x) return x From 214f1a3e32ab43fddb4643a4db1c2e4031070d6c Mon Sep 17 00:00:00 2001 From: 
Walter Hugo Lopez Pinaya Date: Mon, 26 Dec 2022 14:47:30 +0000 Subject: [PATCH 2/3] [WIP] Fix half precision for DiffusionModelUNet Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/diffusion_model_unet.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 7752221b..cd85cda2 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -1622,6 +1622,11 @@ def forward( """ # 1. time t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embed(t_emb) # 2. class @@ -1629,6 +1634,7 @@ def forward( if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=self.dtype) emb = emb + class_emb # 3. initial convolution From 3fd1638d76c91df4835126b715355e90d6fa00e3 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 26 Dec 2022 14:51:00 +0000 Subject: [PATCH 3/3] [WIP] Fix half precision for DiffusionModelUNet Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/diffusion_model_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index cd85cda2..48a9b3f6 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -1626,7 +1626,7 @@ def forward( # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. 
so we need to cast here. # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) + t_emb = t_emb.to(dtype=x.dtype) emb = self.time_embed(t_emb) # 2. class @@ -1634,7 +1634,7 @@ def forward( if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") class_emb = self.class_embedding(class_labels) - class_emb = class_emb.to(dtype=self.dtype) + class_emb = class_emb.to(dtype=x.dtype) emb = emb + class_emb # 3. initial convolution