@patrickvonplaten patrickvonplaten commented Jul 22, 2023

What does this PR do?

🚨🚨🚨 1. Breaking change - fixes mask input 🚨🚨🚨

NOW: mask_image repaints white pixels and preserves black pixels

Kandinsky was using an incorrect mask format. Instead of using white pixels to mark the region to repaint (like SD & IF do), Kandinsky models were using black pixels. This needs to be corrected so that the diffusers API is aligned. We cannot have different mask formats for different pipelines.

Important => This means that everyone who already uses Kandinsky Inpaint in production or in their own pipeline now needs to invert the mask:

# For PIL input
import PIL.ImageOps
mask = PIL.ImageOps.invert(mask)

# For PyTorch and Numpy input
mask = 1 - mask

Once this PR is merged we also need to correct all the model cards (cc @yiyixuxu)
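
For code that still produces masks in the old convention, the inversion above can be wrapped in a small helper. This is only a minimal sketch and not part of this PR: the function name `to_new_mask_format` is hypothetical, and it assumes a float NumPy mask with values in [0, 1] where the old format marked the region to repaint with 0.

```python
import numpy as np


def to_new_mask_format(old_mask: np.ndarray) -> np.ndarray:
    """Convert an old-style Kandinsky mask (0 = repaint) to the new one (1 = repaint).

    Hypothetical helper for illustration; assumes float values in [0, 1].
    """
    old_mask = np.asarray(old_mask, dtype=np.float32)
    if old_mask.min() < 0.0 or old_mask.max() > 1.0:
        raise ValueError("expected mask values in [0, 1]")
    return 1.0 - old_mask


# Old convention: everything is preserved (1) except the region to repaint (0).
old_mask = np.ones((768, 768), dtype=np.float32)
old_mask[:250, 250:-250] = 0

# New convention: the region to repaint is white (1), the rest is black (0).
new_mask = to_new_mask_format(old_mask)
```

The range check is there so that a mask stored as 0 to 255 integers is rejected instead of being silently turned into negative values.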

2. Adds combined pipelines

As noticed in #4161 by @vladmandic, diffusers currently has an inconsistent design between Kandinsky pipelines and other pipelines. The reason for this is that all Kandinsky pipelines (txt2img, img2img & inpaint) are based on DALL-E 2's unCLIP design, meaning they have to run two diffusion pipelines:

  • prior, which diffuses text embeddings to image embeddings (the same pipeline for t2i, img2img and inpaint)
  • decoder, which diffuses image embeddings to images (t2i, img2img and inpaint each have a different pipeline)

Running just the prior or the decoder on its own often makes no sense so we should give the user an easier UX here while making sure we still keep the pipelines separated so that they can be run independently (e.g. on different nodes).
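
The two-stage structure, and the idea of wrapping both stages while keeping them separately usable, can be sketched without any model weights. The classes below are toy stand-ins with hypothetical names (plain Python lists instead of tensors), not the diffusers implementation; they only show a combined object forwarding a prompt through the prior and then the decoder.

```python
from dataclasses import dataclass
from typing import List


class ToyPrior:
    """Stand-in for the prior: maps a prompt to a fake 'image embedding'."""

    def __call__(self, prompt: str) -> List[float]:
        # A real prior runs a diffusion process; here we just derive numbers from the text.
        return [float(len(word)) for word in prompt.split()]


class ToyDecoder:
    """Stand-in for a decoder: maps an 'image embedding' to a fake 'image'."""

    def __call__(self, image_embeds: List[float]) -> str:
        return f"image({sum(image_embeds):.0f})"


@dataclass
class ToyCombinedPipeline:
    """Holds both stages and runs them back to back."""

    prior: ToyPrior
    decoder: ToyDecoder

    def __call__(self, prompt: str) -> str:
        image_embeds = self.prior(prompt)
        return self.decoder(image_embeds)


prior, decoder = ToyPrior(), ToyDecoder()
combined = ToyCombinedPipeline(prior=prior, decoder=decoder)

# Running the stages separately gives the same result as the combined call.
separate_result = decoder(prior("a lion in galaxies"))
combined_result = combined("a lion in galaxies")
```

Because the combined object only holds references to the two stages, each stage can still be constructed and called independently, for example on a different node.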

This PR introduces a mechanism that loads the required prior pipelines directly when loading a decoder pipeline and puts all components in a single "combined" pipeline. Required pipelines are defined in the decoder's model card here: https://huggingface.co/kandinsky-community/kandinsky-2-1/blob/73bf6fba5b4410c671f7c73279ab39932b3ad021/README.md?code=true#L4

Each decoder (txt2img, img2img & inpaint) therefore is now accompanied by a "Combined" pipeline that will automatically be called from AutoPipelineFor{Text2Img,Img2Img,Inpaint}.
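
To illustrate the model card lookup mentioned above, here is a rough sketch of reading connected repository ids from a README's front matter. It is not the code added in this PR: the function names are hypothetical, the example README is made up, and the layout (a `prior` key followed by a list of repo ids) is an assumption about how the model card is written.

```python
from typing import Dict, List

# Assumed shape of the model card front matter; the real README may differ.
EXAMPLE_README = """\
---
license: apache-2.0
prior:
- kandinsky-community/kandinsky-2-1-prior
---
# Kandinsky 2.1
"""


def read_front_matter_lists(readme_text: str) -> Dict[str, List[str]]:
    """Collect `key:` entries followed by `- item` lines from the front matter.

    Deliberately tiny: handles only the flat list layout shown above, not general YAML.
    """
    lines = readme_text.splitlines()
    if not lines or lines[0].strip() != "---":
        return {}
    result: Dict[str, List[str]] = {}
    current_key = None
    for line in lines[1:]:
        stripped = line.strip()
        if stripped == "---":  # end of front matter
            break
        if stripped.startswith("- ") and current_key is not None:
            result[current_key].append(stripped[2:].strip())
        elif stripped.endswith(":"):
            current_key = stripped[:-1]
            result[current_key] = []
        else:
            current_key = None  # scalar entry such as `license: apache-2.0`
    return result


def connected_repo_ids(readme_text: str, keys=("prior",)) -> Dict[str, str]:
    """Return the first repo id listed under each connected-pipeline key."""
    lists = read_front_matter_lists(readme_text)
    return {key: lists[key][0] for key in keys if lists.get(key)}


connected = connected_repo_ids(EXAMPLE_README)
```

Keeping the link in the model card means the decoder repository itself stays unchanged and the prior can still be loaded on its own.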

The following use cases are now supported and thereby make sure Kandinsky models can be used with the same API as other models:

Text 2 Image

#!/usr/bin/env python3
from diffusers import AutoPipelineForText2Image
import torch

pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
# or pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)

prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k"

image = pipe(prompt=prompt, num_inference_steps=25).images[0]

Img2Img

from diffusers import AutoPipelineForImage2Image
import torch
import requests
from io import BytesIO
from PIL import Image

pipe = AutoPipelineForImage2Image.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
# or pipe = AutoPipelineForImage2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

prompt = "A fantasy landscape, Cinematic lighting"
negative_prompt = "low quality, bad quality"

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

# Download the input sketch and shrink it in place to at most 768x768
response = requests.get(url)
original_image = Image.open(BytesIO(response.content)).convert("RGB")
original_image.thumbnail((768, 768))

image = pipe(prompt=prompt, image=original_image, num_inference_steps=25).images[0]

Inpaint

from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image
import torch
import numpy as np

pipe = AutoPipelineForInpainting.from_pretrained("kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16)
# or pipe = AutoPipelineForInpainting.from_pretrained("kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

prompt = "A fantasy landscape, Cinematic lighting"
negative_prompt = "low quality, bad quality"

original_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)

mask = np.zeros((768, 768), dtype=np.float32)
# Let's mask out an area above the cat's head
mask[:250, 250:-250] = 1

image = pipe(prompt=prompt, image=original_image, mask_image=mask, num_inference_steps=25).images[0]

To achieve this the following pipelines have been added:

KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
KandinskyInpaintCombinedPipeline,
KandinskyV22CombinedPipeline,
KandinskyV22Img2ImgCombinedPipeline,
KandinskyV22InpaintCombinedPipeline,

Edit: updated mask in inpaint example.

@patrickvonplaten patrickvonplaten mentioned this pull request Jul 25, 2023
2 tasks
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@vladmandic
Contributor

nice - thank you @patrickvonplaten

Contributor

@williamberman williamberman left a comment

nice! makes sense to me

mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)

mask = 1 - mask
Contributor Author

@yiyixuxu we need to do a breaking change here because Kandinsky can't accept a different mask format than SD.

Let's make sure in the future that our APIs are always exactly aligned. We cannot have one pipeline having black pixels representing a mask and another having white pixels representing a mask.

@patrickvonplaten patrickvonplaten changed the title from "Add combined pipeline kandinsky" to "[Kandinsky] Add combined pipelines / Fix cpu model offload / Fix inpainting" Jul 25, 2023
@patrickvonplaten
Contributor Author

Doing tests & docs tomorrow

Comment on lines +215 to +217
mask = np.zeros((768, 768), dtype=np.float32)
# Let's mask out an area above the cat's head
- mask[:250, 250:-250] = 0
+ mask[:250, 250:-250] = 1
Member

Let's maybe add a "Tip warning" here so that users are clearly informed about this change?

Maybe something like:

<Tip warning=true>

Note that the above change was introduced in the following pull request: https://github.com/huggingface/diffusers/pull/4207. So, if you're using a source installation of `diffusers` or the latest release, you should upgrade your inpainting code to follow the above. 

</Tip>

WDYT?

Member

Also the comment in the PR description is wrong (just in case we want to copy/paste it somewhere)

Comment on lines 89 to 100
AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
[
("kandinsky", KandinskyPipeline),
("kandinsky22", KandinskyV22Pipeline),
]
)
AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
[
("kandinsky", KandinskyImg2ImgPipeline),
("kandinsky22", KandinskyV22Img2ImgPipeline),
]
)
Member

@sayakpaul sayakpaul Jul 26, 2023

Should potentially add similar auto-classes for Stable unCLIP as well, no (follow-up)?

Contributor Author

Yeah I'd say follow-up

@sayakpaul
Member

Currently, installing from this branch leads to:

/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py in <module>
     61         get_scheduler,
     62     )
---> 63     from .pipelines import (
     64         AudioPipelineOutput,
     65         AutoPipelineForImage2Image,

/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/__init__.py in <module>
     18     from ..utils.dummy_pt_objects import *  # noqa F403
     19 else:
---> 20     from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image
     21     from .consistency_models import ConsistencyModelPipeline
     22     from .dance_diffusion import DanceDiffusionPipeline

/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/auto_pipeline.py in <module>
     18 
     19 from ..configuration_utils import ConfigMixin
---> 20 from .controlnet import (
     21     StableDiffusionControlNetImg2ImgPipeline,
     22     StableDiffusionControlNetInpaintPipeline,

ImportError: cannot import name 'StableDiffusionXLControlNetPipeline' from 'diffusers.pipelines.controlnet' (/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/controlnet/__init__.py)

---------------------------------------------------------------------------
NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.

To view examples of installing some common dependencies, click the
"Open Examples" button below.
---------------------------------------------------------------------------

Member

@sayakpaul sayakpaul left a comment

Looks very nice!

However, I am unable to run the code snippets when diffusers is installed from the PR branch (pip install -q git+https://github.com/huggingface/diffusers@add_combined_pipeline_kandinsky).

Collaborator

@yiyixuxu yiyixuxu left a comment

another one of these 🤯 PRs 😂
thank you!

- return getattr(diffusers_module, config["_class_name"])
+ pipeline_cls = getattr(diffusers_module, config["_class_name"])

if load_connected_pipeline:
Collaborator

when is this load_connected_pipeline needed? when do we need to set this flag to be True?

so far it seems like we are able to find the correct combined pipeline class just using AutoPipeline.from_pretrained()

Contributor Author

Yeah I need to write some better docs about it (will do in a follow-up PR). Essentially, it is needed in case one wants to load a connected pipeline via DiffusionPipeline.
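
As a rough illustration of that use case, the flag could be passed to `DiffusionPipeline.from_pretrained` as in the sketch below. The wrapper function name is hypothetical, the flag name is taken from the diff above, and the exact behaviour is still to be documented, so treat this as an assumption-laden sketch.

```python
def load_with_connected_pipelines(repo_id: str = "kandinsky-community/kandinsky-2-1"):
    """Load a decoder checkpoint together with the pipelines its model card links to."""
    # Imported lazily so that defining this function does not require diffusers to be installed.
    from diffusers import DiffusionPipeline

    # `load_connected_pipeline=True` asks `from_pretrained` to also resolve the connected prior.
    return DiffusionPipeline.from_pretrained(repo_id, load_connected_pipeline=True)
```

Calling `load_with_connected_pipelines()` downloads the checkpoints, so it is left uncalled here.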

@yiyixuxu
Collaborator

are we not going to add kandinsky controlnet to sd.next?

Member

@pcuenca pcuenca left a comment

Very awesome, we could potentially use the same mechanism for SDXL.

]
)

AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
Member

In what context will these be used?

I'm thinking about potentially creating auto pipelines for SDXL, the combined one is clear, but where would the base and refiner steps go? Is Decoder a good name or too specific to these models?

Contributor Author

Yeah good point, it's not in the public API so we can definitely change the naming later. I'll make those mappings private for now to allow them to be changed after

Comment on lines +259 to +264
prior_prior: PriorTransformer,
prior_image_encoder: CLIPVisionModelWithProjection,
prior_text_encoder: CLIPTextModelWithProjection,
prior_tokenizer: CLIPTokenizer,
prior_scheduler: UnCLIPScheduler,
prior_image_processor: CLIPImageProcessor,
Member

No docstrings (not sure if they are important)

Contributor Author

Adding all those now :-)

_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

# We'll offload the last model manually.
self.prior_hook = hook
Member

Why do we need to store this? Should be used in interpolate?

Comment on lines +565 to +566
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
    self.prior_hook.offload()
Member

Suggested change
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-     self.prior_hook.offload()
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+     self.final_offload_hook.offload()

If I am understanding this correctly we should always offload final_offload_hook, so we could put it outside the if.

Are we missing a self.prior_hook.offload() in interpolate?

Contributor Author

This is actually expected since for this branch the prior is the final model
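
For readers following this thread, the hook chaining under discussion can be simulated in plain Python. This is a toy model under stated assumptions, not accelerate's or diffusers' implementation: `ToyHook`, its `pre_forward` method and the two model names are made up for illustration. It shows why the hook of the last model in the chain has to be stored and offloaded manually.

```python
from typing import Dict, List, Optional


class ToyHook:
    """Toy stand-in for an offload hook; tracks whether its model is on the accelerator."""

    def __init__(self, name: str, prev_hook: Optional["ToyHook"], log: List[str]):
        self.name = name
        self.prev_hook = prev_hook
        self.log = log
        self.on_device = False

    def pre_forward(self) -> None:
        # Before this model runs, push the previous one back to the CPU.
        if self.prev_hook is not None:
            self.prev_hook.offload()
        self.on_device = True
        self.log.append(f"load {self.name}")

    def offload(self) -> None:
        if self.on_device:
            self.on_device = False
            self.log.append(f"offload {self.name}")


log: List[str] = []
hooks: Dict[str, ToyHook] = {}
hook = None
for name in ["text_encoder", "prior"]:
    # Each hook remembers the previous one, mirroring the `prev_module_hook` chaining.
    hook = ToyHook(name, prev_hook=hook, log=log)
    hooks[name] = hook

# Nothing runs after the last model, so its hook is kept and offloaded manually.
final_offload_hook = hook

for name in ["text_encoder", "prior"]:
    hooks[name].pre_forward()  # simulate running each model once

final_offload_hook.offload()
```

In this toy run the prior is the last model in the chain, so offloading the stored hook and offloading the prior are the same operation.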

@patrickvonplaten patrickvonplaten merged commit b3e5cd6 into main Jul 26, 2023
orpatashnik pushed a commit to orpatashnik/diffusers that referenced this pull request Aug 1, 2023
…inting (huggingface#4207)

* Add combined pipeline

* Download readme

* Upload

* up

* up

* fix final

* Add enable model cpu offload kandinsky

* finish

* finish

* Fix

* fix more

* make style

* fix kandinsky mask

* fix inpainting test

* add callbacks

* add tests

* fix tests

* Apply suggestions from code review

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* docs

* docs

* correct docs

* fix tests

* add warning

* correct docs

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
@kashif kashif deleted the add_combined_pipeline_kandinsky branch September 11, 2023 19:07
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
…inting (huggingface#4207)

AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
…inting (huggingface#4207)
