Skip to content

Conversation

@dg845
Copy link
Collaborator

@dg845 dg845 commented Mar 28, 2023

This PR adds a warning in StableDiffusionInpaintPipeline's __init__ method if the user initializes the pipeline with a UNet with unet.config.in_channels other than 9. This follows up on #2799 .

This is because a Stable Diffusion inpainting model typically has 9 input channels: 4 for the latent (encoded) base image (num_channels_latents), 1 for the inpainting mask (num_channels_mask), and 4 for the latent masked image (num_channels_masked_image). However, a valid SD inpainting model could have a different number of input channels, which is why we raise a warning instead of an exception; the precise condition that must be satisfied is that num_channels_latents + num_channels_mask + num_channels_masked_image == self.unet.config.in_channels, which is checked in the __call__ method. See #2799 for more details.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Mar 28, 2023

The documentation is not available anymore as the PR was closed or merged.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well done!

@patrickvonplaten I don't think we need add a test for this?

@patrickvonplaten
Copy link
Contributor

This looks great, nice warning message @dg845

@patrickvonplaten patrickvonplaten merged commit 4d0f412 into huggingface:main Mar 28, 2023
@dg845 dg845 deleted the sd-inpaint-check-shapes branch March 28, 2023 22:15
w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
…uggingface#2853)

Add warning in __init__ if user loads a checkpoint with pipeline.unet.config.in_channels other than 9.
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
…uggingface#2853)

Add warning in __init__ if user loads a checkpoint with pipeline.unet.config.in_channels other than 9.
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
…uggingface#2853)

Add warning in __init__ if user loads a checkpoint with pipeline.unet.config.in_channels other than 9.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants