
Conversation

@kfzyqin (Contributor) commented Aug 21, 2023

Overview:

This PR implements the inference pipeline for ControlNet inpainting with SDXL.

Files Modified/Added:

  1. Inference Pipeline: src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
    • The main implementation of the ControlNet SDXL inpainting inference pipeline.
  2. Unit Test: tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
    • Unit tests that verify the correctness and robustness of the implemented pipeline.

Visualizations:

To better understand the impact and functionality of the implemented pipeline, the following visualizations are provided:

  1. Input Image
  2. Mask
  3. Output Image

Example Usage

import torch 
from PIL import Image
from transformers import DPTForDepthEstimation, DPTFeatureExtractor
import numpy as np 
import cv2 


# Estimate a depth map with DPT and return it as a 3-channel PIL image.
def get_depth_map(image):
    depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
    feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(512, 512),
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)

    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image

def inpaint_with_controlnet():
    from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline
    from diffusers.utils import load_image

    img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
    mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

    controlnet = [
        # ControlNetModel.from_pretrained(
        #     "diffusers/controlnet-depth-sdxl-1.0", use_auth_token=True, torch_dtype=torch.float32
        # ), 
        ControlNetModel.from_pretrained(
            "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float32
        ),
    ]

    pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", 
        controlnet=controlnet,
        torch_dtype=torch.float32, 
    )
    pipe.to("cuda")

    init_image = load_image(img_url).convert("RGB")
    depth_image = get_depth_map(init_image)
    
    canny_image = np.array(init_image)

    low_threshold = 100
    high_threshold = 200

    canny_image = cv2.Canny(canny_image, low_threshold, high_threshold)

    # zero out the middle columns of the edge map where the new content will be overlaid
    zero_start = canny_image.shape[1] // 4
    zero_end = zero_start + canny_image.shape[1] // 2
    canny_image[:, zero_start:zero_end] = 0

    canny_image = canny_image[:, :, None]
    canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
    canny_image = Image.fromarray(canny_image).resize((1024, 1024))
    
    mask_image = load_image(mask_url).convert("RGB")
    
    original_width, original_height = init_image.size
    new_width = int(original_width / 2)
    new_height = int(original_height / 2)
    init_image = init_image.resize((new_width, new_height))
    mask_image = mask_image.resize((new_width, new_height))
    depth_image = depth_image.resize((new_width, new_height))
    canny_image = canny_image.resize((new_width, new_height))
    
    prompt = "black cat with green eyes"
    strength=1.0
    controlnet_conditioning_scale = 0.3

    # The active ControlNet above is the Canny one, so pass the Canny image as the control.
    canny_image.save('control_image.jpg')
    image = pipe(
        prompt=prompt,
        image=init_image,
        mask_image=mask_image,
        control_image=[canny_image],
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        strength=strength,
        width=1024,
        height=1024,
    ).images[0]

    image.save('result_sdxl_inpaint.jpg')
    
    
if __name__ == "__main__":
    inpaint_with_controlnet()

Features

  • Supports MultiControlNet (usage sketch below)
  • Compatible with the latest Hugging Face diffusers code
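
A minimal MultiControlNet usage sketch. This is a sketch under assumptions, not part of the PR: the checkpoints and the 0.3 scales are illustrative, the list-based API is assumed to mirror the existing SDXL ControlNet pipeline, and init_image, mask_image, canny_image, and depth_image are prepared as in the example above.

import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline

# Two ControlNets applied together; the pipeline wraps them into a MultiControlNet.
controlnets = [
    ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float32),
    ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float32),
]

pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnets,
    torch_dtype=torch.float32,
).to("cuda")

# One control image per ControlNet, in the same order as the list above;
# conditioning scales can likewise be given per ControlNet.
image = pipe(
    prompt="black cat with green eyes",
    image=init_image,
    mask_image=mask_image,
    control_image=[canny_image, depth_image],
    controlnet_conditioning_scale=[0.3, 0.3],
).images[0]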

@kfzyqin marked this pull request as draft on August 21, 2023, 21:59
@kfzyqin changed the title from "[ControlNet SDXL Inpainting] Support inpainting of ControlNet SDXL" to "[(Draft) ControlNet SDXL Inpainting] Support inpainting of ControlNet SDXL" on Aug 21, 2023
@Cathy0908 commented

Wow, I really need this. Does it work now? I always get black pictures with it. Can you post the API usage? Thanks a lot!

@kfzyqin (Contributor, Author) commented Aug 22, 2023

> Wow, I really need this. Does it work now? I always get black pictures with it. Can you post the API usage? Thanks a lot!

I discovered some issues today, but it should generate sensible images rather than black ones ...

Let me complete this within the week.

Feel free to add me on Discord: harutatsuakiyama

@kfzyqin marked this pull request as ready for review on August 22, 2023, 11:53
@kfzyqin changed the title from "[(Draft) ControlNet SDXL Inpainting] Support inpainting of ControlNet SDXL" to "[ControlNet SDXL Inpainting] Support inpainting of ControlNet SDXL" on Aug 22, 2023
@kfzyqin (Contributor, Author) commented Aug 22, 2023

> Wow, I really need this. Does it work now? I always get black pictures with it. Can you post the API usage? Thanks a lot!

I fixed the issue yesterday. The code should work as expected.

@Cathy0908 commented

I use the following pipeline but still get a black image.
When I replace StableDiffusionXLControlNetInpaintPipeline with StableDiffusionXLInpaintPipeline, it works well.
Is there something wrong with my code?

def inpaint_with_controlnet():
    import torch
    from diffusers import StableDiffusionXLInpaintPipeline
    from diffusers.utils import load_image
    from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
    from pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline

    img_url = "https://user-images.githubusercontent.com/8084808/262496067-e01fb3c9-aece-4560-ae64-6354fdd789d7.png"
    mask_url = "https://user-images.githubusercontent.com/8084808/262496139-234e0049-43ab-415b-ae6d-4cbb96055f6d.png"
    control_image_url = img_url

    # Compute openpose conditioning image.
    from controlnet_aux import OpenposeDetector
    openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
    control_image = openpose(load_image(control_image_url))

    controlnet = ControlNetModel.from_pretrained("thibaud/controlnet-openpose-sdxl-1.0", torch_dtype=torch.float16)

    pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", 
        controlnet=controlnet,
        torch_dtype=torch.float16, 
    )
    pipe.to("cuda")

    init_image = load_image(img_url).convert("RGB")
    mask_image = load_image(mask_url).convert("RGB")

    prompt = "hand"
    strength=0.5
    controlnet_conditioning_scale = 1.0

    image = pipe(
        prompt=prompt,
        image=init_image,
        mask_image=mask_image,
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        strength=strength,
    ).images[0]

    image.save('result.jpg')

@kfzyqin (Contributor, Author) commented Aug 23, 2023

def inpaint_with_controlnet():
    import torch
    from diffusers import StableDiffusionXLInpaintPipeline
    from diffusers.utils import load_image
    from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
    from pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline

    img_url = "https://user-images.githubusercontent.com/8084808/262496067-e01fb3c9-aece-4560-ae64-6354fdd789d7.png"
    mask_url = "https://user-images.githubusercontent.com/8084808/262496139-234e0049-43ab-415b-ae6d-4cbb96055f6d.png"
    control_image_url = img_url

    # Compute openpose conditioning image.
    from controlnet_aux import OpenposeDetector
    openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
    control_image = openpose(load_image(control_image_url))

    controlnet = ControlNetModel.from_pretrained("thibaud/controlnet-openpose-sdxl-1.0", torch_dtype=torch.float16)

    pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", 
        controlnet=controlnet,
        torch_dtype=torch.float16, 
    )
    pipe.to("cuda")

    init_image = load_image(img_url).convert("RGB")
    mask_image = load_image(mask_url).convert("RGB")

    prompt = "hand"
    strength=0.5
    controlnet_conditioning_scale = 1.0

    image = pipe(
        prompt=prompt,
        image=init_image,
        mask_image=mask_image,
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        strength=strength,
    ).images[0]

    image.save('result.jpg')

Thank you for the code! You need to use torch.float32 instead of torch.float16. I tested the following code; it should work:

def inpaint_with_controlnet():
    import torch
    from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline
    from diffusers.utils import load_image

    img_url = "https://user-images.githubusercontent.com/8084808/262496067-e01fb3c9-aece-4560-ae64-6354fdd789d7.png"
    mask_url = "https://user-images.githubusercontent.com/8084808/262496139-234e0049-43ab-415b-ae6d-4cbb96055f6d.png"
    control_image_url = img_url

    # Compute openpose conditioning image.
    from controlnet_aux import OpenposeDetector
    openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
    control_image = openpose(load_image(control_image_url))

    controlnet = ControlNetModel.from_pretrained("thibaud/controlnet-openpose-sdxl-1.0", torch_dtype=torch.float32)

    pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", 
        controlnet=controlnet,
        torch_dtype=torch.float32, 
    )
    pipe.to("cuda")

    init_image = load_image(img_url).convert("RGB")
    mask_image = load_image(mask_url).convert("RGB")
    
    original_width, original_height = init_image.size
    new_width = int(original_width / 2)
    new_height = int(original_height / 2)
    init_image = init_image.resize((new_width, new_height))
    mask_image = mask_image.resize((new_width, new_height))
    control_image = control_image.resize((new_width, new_height))

    prompt = "hand"
    strength = 0.5
    controlnet_conditioning_scale = 1.0

    image = pipe(
        prompt=prompt,
        image=init_image,
        mask_image=mask_image,
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        strength=strength,
    ).images[0]

    image.save('result.jpg')
    
    
if __name__ == "__main__":
    inpaint_with_controlnet()

Feel free to add me on Discord and we can discuss there.
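
A side note for readers who need to stay in fp16: all-black SDXL outputs are the classic symptom of the stock VAE overflowing in half precision, and a commonly used workaround is to swap in an fp16-safe VAE. A hedged sketch follows; the madebyollin/sdxl-vae-fp16-fix checkpoint is a community fix and is not part of this PR.

import torch
from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetInpaintPipeline

# fp16-safe VAE: decodes without the NaN overflow that produces black images.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

controlnet = ControlNetModel.from_pretrained("thibaud/controlnet-openpose-sdxl-1.0", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,  # replace only the VAE, keep everything else in fp16
    torch_dtype=torch.float16,
).to("cuda")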

@patrickvonplaten (Contributor) commented

Very cool PR! @yiyixuxu can you give this a look? :-)

@yiyixuxu (Collaborator) left a comment:

Thanks! Excellent work!

I think the two main things left are:

  1. Refactor with a mask_image_processor: https://github.com/huggingface/diffusers/pull/4444/files
  2. Add MultiControlNet support



# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
Collaborator:

We just deprecated this function :)
in PR #4444 (comment)
let's update this PR too

@kfzyqin (Contributor, Author) replied Aug 31, 2023:

Updated

self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
    vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
self.control_image_processor = VaeImageProcessor(
    vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
self.watermark = StableDiffusionXLWatermarker()
Collaborator:

add a mask_processor here

Contributor (Author):

Done

generator = torch.Generator(device=device).manual_seed(seed)

controlnet_embedder_scale_factor = 2
control_image = randn_tensor(
Collaborator:

I think we accept image tensors in the [0, 1] range, so we should not use randn_tensor here

Contributor (Author):

Thank you! Corrected.

control_image = (
    floats_tensor(
        (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
        rng=random.Random(seed),
    )
    .to(device)
    .cpu()
)

init_image = init_image.cpu().permute(0, 2, 3, 1)[0]

controlnet_embedder_scale_factor = 2
image = Image.fromarray(np.uint8(init_image)).convert("RGB").resize((64, 64))
Collaborator:

the dummy image and mask_image are just 2 black images here

let's do something similar to https://github.com/huggingface/diffusers/pull/4536/files#diff-b65a24df736726ca6f92c71567b77c2a9832ee6142ee2dcbdb08e9addcb6da4b

Contributor (Author):

Followed the linked code:

image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image.cpu().permute(0, 2, 3, 1)[0]
mask_image = torch.ones_like(image)
controlnet_embedder_scale_factor = 2
control_image = (
    floats_tensor(
        (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
        rng=random.Random(seed),
    )
    .to(device)
    .cpu()
)

assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4

# Ignore float16 for SDXL
def test_float16_inference(self):
Collaborator:

why do we disable this?

Contributor (Author):

This was unintentional. Removed the disabling.

@kfzyqin (Contributor, Author) commented Aug 27, 2023

Thank you @yiyixuxu and @patrickvonplaten. I will work on comments this week.

@kfzyqin (Contributor, Author) commented Aug 29, 2023

Borrowing ideas from PR #4811. Work in progress.

@patrickvonplaten (Contributor) commented

Hey @viiika,

Could we maybe work on this PR together? @harutatsuakiyama, can you invite @viiika as a collaborator on your fork so that we can work here?

@viiika, it's quite rare that two PRs for the same feature pop up almost at the same time; very sorry for the potentially duplicated work. Would it be OK to carry on with this PR, because:

  • we already reviewed this PR
  • this PR was up a bit earlier

It would be very nice if we could collaborate here 🙏

return mask


def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
Contributor:

Can we remove this function and instead use the new mask processor logic: #4444

Collaborator:

@harutatsuakiyama I think you can delete this function now if not used?
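
For reference, a minimal sketch of the mask-processor flow that replaces prepare_mask_and_masked_image inside the pipeline; the mask < 0.5 masking convention is assumed from #4444, and the processor names match those configured earlier in this thread.

# Pipeline-internal sketch: preprocess mask and image, then derive the masked image.
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)

init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)

# Zero out the region to be inpainted (the processor binarizes the mask to {0, 1}).
masked_image = init_image * (mask < 0.5)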

@viiika commented Aug 30, 2023

I still insist that #4811 already supports some of the new features mentioned in #4694, like MultiControlNet, the API usage, no randn_tensor for control_image, and even the mask_image_processor refactor you mentioned just now.

Also, its coding style is more consistent with pipeline_stable_diffusion_xl_inpaint, compared to StableDiffusionControlNetInpaintPipeline, which was adapted from StableDiffusionInpaintPipeline.

I believe #4811 requires almost no effort to review, because it and the latest pipeline_stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint are updated synchronously.

That said, which PR to merge is up to you. And I believe that if you choose #4811, it may take less than a day for us to merge it.

@viiika commented Aug 30, 2023

Also, if you still insist we should continue with #4694, that's fine with me and I will try my best to help fix the problems. I just think merging #4694 will take a few weeks of handling many issues and might introduce some design inconsistencies. A lot of current research relies on this pipeline, so I just hope it gets merged soon.

@kfzyqin (Contributor, Author) commented Sep 1, 2023

Hi @yiyixuxu. Thanks for the review. I have addressed the review comments:

  • Updated the docstring.
  • Removed unnecessary functions.
  • Fixed test errors.

My local tests show no issues. Please let me know if further changes are required :-)

] = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 1.0,
Contributor:

Suggested change
strength: float = 1.0,
strength: float = 0.9999,

Contributor (Author):

Changed, but why?

The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
strength (`float`, *optional*, defaults to 1.):
Contributor:

Suggested change
strength (`float`, *optional*, defaults to 1.):
strength (`float`, *optional*, defaults to 0.9999):

Contributor (Author):

Changed. Out of curiosity, may I ask why?


control_image = control_images
else:
assert False
Contributor:

Suggested change
assert False
raise ValueError(f"{controlnet.__class__} is not supported.")

Contributor (Author):

Changed

@patrickvonplaten (Contributor) left a comment:

Good to merge once @yiyixuxu is ok with it :-)

@patrickvonplaten (Contributor) commented

@viiika could you maybe drop your email here so that we can add you as a co-author via https://docs.github.com/en/pull-requests/committing-changes-to-your-project/creating-and-editing-commits/creating-a-commit-with-multiple-authors

@viiika commented Sep 1, 2023

> @viiika could you maybe drop your email here so that we can add you as a co-author via https://docs.github.com/en/pull-requests/committing-changes-to-your-project/creating-and-editing-commits/creating-a-commit-with-multiple-authors

Sure. My primary GitHub email for this account is 1355864570@qq.com. Thank you very much!

@yiyixuxu (Collaborator) commented Sep 1, 2023

@harutatsuakiyama let's make sure the code quality checks pass. Please run make style :)

@patrickvonplaten (Contributor) commented

> @viiika could you maybe drop your email here so that we can add you as a co-author via https://docs.github.com/en/pull-requests/committing-changes-to-your-project/creating-and-editing-commits/creating-a-commit-with-multiple-authors
>
> Sure. My primary GitHub email for this account is 1355864570@qq.com. Thank you very much!

@harutatsuakiyama could you add @viiika as an author here? That would be very nice ❤️

Co-authored-by: Jiabin Bai <1355864570@qq.com>
@kfzyqin (Contributor, Author) commented Sep 2, 2023

Hi @yiyixuxu, @patrickvonplaten, and @viiika,

I have addressed the new code review comments:

  • Included @viiika as a co-author by adding their name and email to the commit.
  • Fixed the various default-value issues.

As for the failing tests, the previous failure seems to have been due to network issues (a bad gateway error). My local tests pass.

Please let me know if further changes are required.

@yiyixuxu (Collaborator) commented Sep 2, 2023

@harutatsuakiyama could you run make fix-copies and make style? Let's make sure CI is green.

@kfzyqin (Contributor, Author) commented Sep 2, 2023

Thank you @yiyixuxu. I just realized that diffusers/utils/dummy_torch_and_transformers_objects.py had some style problems. I have fixed them.

Below are the outputs of make fix-copies and make style. The make style errors are not from the code I uploaded. I think the CI should be green this time :-)

Let me know if anything else is required.

make fix-copies

python utils/check_copies.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite

make style

black examples scripts src tests utils
All done! ✨ 🍰 ✨
613 files left unchanged.
ruff examples scripts src tests utils --fix
examples/community/lpw_stable_diffusion_xl.py:1141:42: E721 Do not compare types, use `isinstance()`
examples/community/stable_diffusion_xl_reference.py:703:42: E721 Do not compare types, use `isinstance()`
src/diffusers/experimental/rl/value_guided_sampling.py:79:12: E721 Do not compare types, use `isinstance()`
src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py:181:12: E721 Do not compare types, use `isinstance()`
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py:827:42: E721 Do not compare types, use `isinstance()`
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py:909:20: E721 Do not compare types, use `isinstance()`
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py:1132:20: E721 Do not compare types, use `isinstance()`
src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py:877:42: E721 Do not compare types, use `isinstance()`
tests/pipelines/consistency_models/test_consistency_models.py:190:12: E721 Do not compare types, use `isinstance()`
tests/pipelines/unidiffuser/test_unidiffuser.py:112:12: E721 Do not compare types, use `isinstance()`
tests/pipelines/unidiffuser/test_unidiffuser.py:548:12: E721 Do not compare types, use `isinstance()`
tests/pipelines/unidiffuser/test_unidiffuser.py:651:12: E721 Do not compare types, use `isinstance()`
Found 12 errors.
make: *** [Makefile:59: style] Error 1

@kfzyqin (Contributor, Author) commented Sep 2, 2023

Ahh I see, I need to run the doc-builder style check too. Let me do that; I hope that is the last failing check.

Sorry for the failing tests again. Can I ask for hints on how to fix this error? @yiyixuxu Also, could we get access to run the CI tests, for more efficient debugging? I have tried locally and everything seems correct ...

All done! ✨ 🍰 ✨
617 files would be left unchanged.
Traceback (most recent call last):
  File "/opt/hostedtoolcache/Python/3.7.17/x64/bin/doc-builder", line 8, in <module>
    sys.exit(main())
  File "/opt/hostedtoolcache/Python/3.7.17/x64/lib/python3.7/site-packages/doc_builder/commands/doc_builder_cli.py", line 47, in main
    args.func(args)
  File "/opt/hostedtoolcache/Python/3.7.17/x64/lib/python3.7/site-packages/doc_builder/commands/style.py", line 28, in style_command
    raise ValueError(f"{len(changed)} files should be restyled!")
ValueError: 1 files should be restyled!
Error: Process completed with exit code 1.

>>> mask_image = load_image(mask_url).convert("RGB")

>>> original_width, original_height = init_image.size
>>> new_width = int(original_width / 2)
Collaborator:

why do we resize?

Contributor (Author):

This is to save CUDA memory. Removed in the new code.

self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: Union[
Collaborator:

let's use a custom type PipelineImageInput (was recently introduced)

List[PIL.Image.Image],
List[np.ndarray],
] = None,
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
Collaborator:

I think mask_image should be of the same type as image, no? PipelineImageInput
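
For illustration, a sketch of the signature using PipelineImageInput; the import from diffusers.image_processor mirrors the other pipelines, and the parameter list is abbreviated, so this is an assumption about the final shape rather than the merged code.

from typing import List, Optional, Union

from diffusers.image_processor import PipelineImageInput

def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    image: PipelineImageInput = None,
    mask_image: PipelineImageInput = None,  # same type as image, per the comment above
    control_image: PipelineImageInput = None,
    # ... remaining parameters unchanged ...
):
    ...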

latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
Collaborator:

I don't think this line is needed? It hasn't changed from line 1452.

projection_class_embeddings_input_dim=80, # 6 * 8 + 32
cross_attention_dim=64,
)
torch.manual_seed(0)
Collaborator:

Why do we need to fix the seed here? I don't think we have any randomness here, no?

image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

def get_dummy_components(self):
torch.manual_seed(0)
Collaborator:

is this needed?

projection_class_embeddings_input_dim=80, # 6 * 8 + 32
cross_attention_dim=64,
)
torch.manual_seed(0)
Collaborator:

same, needed?

@yiyixuxu (Collaborator) commented Sep 2, 2023

Regarding the quality check, make sure you are up to date: pip install --upgrade -e ".[quality]"

cc @DN6, we need help with the tests here!

@kfzyqin (Contributor, Author) commented Sep 2, 2023

I found the test issue: some lines in the docstring are too long.

@kfzyqin (Contributor, Author) commented Sep 2, 2023

Hi @yiyixuxu. I removed EXAMPLE_DOC_STRING since it keeps failing doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source. I will try bringing it back in the future, maybe with some help from the test experts :-)

For now, I strongly believe the code should pass the tests (fingers crossed 🙏)

@kfzyqin (Contributor, Author) commented Sep 2, 2023

Hi @yiyixuxu, thanks for the new review round. I have addressed the comments:

  • The code now uses PipelineImageInput.
  • Added guess_mode (see the sketch after this list).
  • Added EXAMPLE_DOC_STRING.
  • Added a test for guess_mode.
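
A hedged usage sketch for guess_mode, reusing the names from the example earlier in this thread; the parameter is assumed to behave as it does in the other ControlNet pipelines.

# guess_mode lets the ControlNet infer content from the control image
# even when the prompt is empty or uninformative.
image = pipe(
    prompt=prompt,
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    strength=strength,
    guess_mode=True,
).images[0]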

Also, I strongly believe the code should pass the tests (fingers crossed 🙏)

Let me know if further changes are required.

@yiyixuxu merged commit c52acaa into huggingface:main on Sep 2, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
…uggingface#4694)

* [ControlNet SDXL Inpainting] Support inpainting of ControlNet SDXL

Co-authored-by: Jiabin Bai <1355864570@qq.com>


---------

Co-authored-by: Harutatsu Akiyama <kf.zy.qin@gmail.com>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
…uggingface#4694)

* [ControlNet SDXL Inpainting] Support inpainting of ControlNet SDXL

Co-authored-by: Jiabin Bai <1355864570@qq.com>


---------

Co-authored-by: Harutatsu Akiyama <kf.zy.qin@gmail.com>