From ef41671249dc79651b9b020140eef390f8faa683 Mon Sep 17 00:00:00 2001 From: kadirnar Date: Mon, 12 Jun 2023 18:17:22 +0300 Subject: [PATCH 1/3] added the _default_height_width function. --- .../stable_diffusion_controlnet_reference.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/examples/community/stable_diffusion_controlnet_reference.py b/examples/community/stable_diffusion_controlnet_reference.py index ca06136d7829..baea508be6ed 100644 --- a/examples/community/stable_diffusion_controlnet_reference.py +++ b/examples/community/stable_diffusion_controlnet_reference.py @@ -63,6 +63,31 @@ def torch_dfs(model: torch.nn.Module): class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline): + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + + height = (height // 8) * 8 # round down to nearest multiple of 8 + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + + width = (width // 8) * 8 # round down to nearest multiple of 8 + + return height, width + def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance): refimage = refimage.to(device=device, dtype=dtype) From 2e32c3423a09f722be2271f5ac9c23081dcb157d Mon Sep 17 00:00:00 2001 From: kadirnar Date: Mon, 12 Jun 2023 18:26:34 +0300 Subject: [PATCH 2/3] The check_inputs parameters have been updated. --- examples/community/stable_diffusion_controlnet_reference.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/community/stable_diffusion_controlnet_reference.py b/examples/community/stable_diffusion_controlnet_reference.py index baea508be6ed..36c80b85f19d 100644 --- a/examples/community/stable_diffusion_controlnet_reference.py +++ b/examples/community/stable_diffusion_controlnet_reference.py @@ -255,8 +255,6 @@ def __call__( self.check_inputs( prompt, image, - height, - width, callback_steps, negative_prompt, prompt_embeds, From 0d48e128b5ad4870d44fd771c744c616405a4a13 Mon Sep 17 00:00:00 2001 From: kadirnar Date: Mon, 12 Jun 2023 18:35:25 +0300 Subject: [PATCH 3/3] make style --- examples/community/stable_diffusion_controlnet_reference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/stable_diffusion_controlnet_reference.py b/examples/community/stable_diffusion_controlnet_reference.py index 36c80b85f19d..e974898f5f6b 100644 --- a/examples/community/stable_diffusion_controlnet_reference.py +++ b/examples/community/stable_diffusion_controlnet_reference.py @@ -87,7 +87,7 @@ def _default_height_width(self, height, width, image): width = (width // 8) * 8 # round down to nearest multiple of 8 return height, width - + def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance): refimage = refimage.to(device=device, dtype=dtype)