From 4d53a521508955e47b8bdac2f76891136135ad16 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 8 Jun 2022 11:44:27 +0200 Subject: [PATCH] add unet ldm in init --- src/diffusers/__init__.py | 1 + src/diffusers/models/__init__.py | 1 + src/diffusers/models/unet_ldm.py | 4 ++-- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3ce4142f65c2..8feb9e81ad26 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -7,5 +7,6 @@ from .modeling_utils import ModelMixin from .models.unet import UNetModel from .models.unet_glide import UNetGLIDEModel +from .models.unet_ldm import UNetLDMModel from .pipeline_utils import DiffusionPipeline from .schedulers.gaussian_ddpm import GaussianDDPMScheduler diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 85f1cc03f667..6d6c4d3d08a9 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -18,3 +18,4 @@ from .unet import UNetModel from .unet_glide import UNetGLIDEModel +from .unet_ldm import UNetLDMModel diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index 465c168c83ad..57dec0b60696 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -830,7 +830,7 @@ def __init__( self.conv_resample = conv_resample self.num_classes = num_classes self.use_checkpoint = use_checkpoint - self.dtype = torch.float16 if use_fp16 else torch.float32 + self.dtype_ = torch.float16 if use_fp16 else torch.float32 self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample @@ -1060,7 +1060,7 @@ def forward(self, x, timesteps=None, context=None, y=None,**kwargs): assert y.shape == (x.shape[0],) emb = emb + self.label_emb(y) - h = x.type(self.dtype) + h = x.type(self.dtype_) for module in self.input_blocks: h = module(h, emb, context) hs.append(h)