Conversation

@gkorepanov

What does this PR do?

Follow-up for #4038 with fixes that allow switching CFG and "guess_mode" in the SDXL ControlNet pipeline.

Who can review?

@sayakpaul can you please suggest whether I need to add some tests here? Also, you mentioned that this bug blocks multi-ControlNet support; should I also try to add it in this PR?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@sayakpaul
Member

Also you mentioned that this bug blocks multi controlnet support, should I also try to add it in this PR?

@gkorepanov sure, let's try to merge that in.

Comment on lines +883 to +943
scaled_latents = self.scheduler.scale_model_input(latents, t)
latent_model_input = torch.cat([scaled_latents] * 2) if do_classifier_free_guidance else scaled_latents
Member

Could I have an explanation on why this differs from

latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

?

Author

Just wanted to make the code around switching CFG clearer: I don't see the point of scaling the latents once for the main UNet and then a second time for the ControlNet. So I made a change aligned with how all the other inputs (prompt embeds, text_embeds, time_ids) are handled: we first prepare the standard inputs for the ControlNet, and then expand them (torch.cat([...] * 2)) for the UNet if required by CFG.
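The "prepare once, then expand for CFG" ordering described here can be sketched as follows. Note `DummyScheduler` and `prepare_unet_input` are hypothetical stand-ins, not the actual pipeline code; real pipelines use a diffusers scheduler whose scaling factor depends on the timestep:

```python
import torch

class DummyScheduler:
    """Stand-in for a diffusers scheduler (hypothetical; real schedulers such
    as EulerDiscreteScheduler scale by a timestep-dependent factor)."""
    def scale_model_input(self, sample, timestep):
        return sample * 0.5  # constant factor, purely for illustration

def prepare_unet_input(latents, t, scheduler, do_classifier_free_guidance):
    # Scale once, then duplicate along the batch dim for CFG -- the same
    # "prepare once, expand for CFG" pattern used for prompt_embeds,
    # text_embeds and time_ids.
    scaled_latents = scheduler.scale_model_input(latents, t)
    if do_classifier_free_guidance:
        return torch.cat([scaled_latents] * 2)
    return scaled_latents
```

Since `scale_model_input` acts elementwise on the latents, scaling before or after the `torch.cat` duplication yields the same values; scaling first simply avoids doing the work twice.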

Member

Got it. Clean.

Collaborator

@yiyixuxu yiyixuxu Aug 23, 2023

I just took a closer look here. Is it possible to do exactly the same as here? https://github.com/huggingface/diffusers/blob/80871ac5971fe7e708befa3b553463c4e61b22ab/src/diffusers/pipelines/controlnet/pipeline_controlnet.py#L938C19-L938C19

I think the logic is very clear there:

  1. all the code that creates inputs for the controlnet model is contained within the if ... else ... statement
  2. it is very clear that the ControlNet input differs from the regular UNet input only when guess_mode and do_classifier_free_guidance are both enabled
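For reference, the logic in the linked pipeline can be sketched roughly like this (a paraphrase wrapped in a hypothetical helper, with the scheduler scaling step omitted for brevity; exact variable names may differ from the real file):

```python
import torch

def prepare_controlnet_inputs(latent_model_input, latents, prompt_embeds,
                              guess_mode, do_classifier_free_guidance):
    # Paraphrase of the referenced pipeline_controlnet.py block: only when
    # guess_mode AND CFG are both enabled does the ControlNet input differ
    # from the regular UNet input -- the ControlNet then sees only the
    # conditional half of the batch.
    if guess_mode and do_classifier_free_guidance:
        control_model_input = latents                      # unduplicated latents
        controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]  # conditional half
    else:
        control_model_input = latent_model_input
        controlnet_prompt_embeds = prompt_embeds
    return control_model_input, controlnet_prompt_embeds
```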

@sayakpaul
Member

sayakpaul commented Jul 20, 2023

@gkorepanov this is a great start!

Let's add a few test cases to ensure feature robustness :)

Also, we need to ensure the existing tests don't fail. This is from the CI:

=========================== short test summary info ============================
FAILED tests/pipelines/controlnet/test_controlnet_sdxl.py::ControlNetPipelineSDXLFastTests::test_inference_batch_consistent - RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x56 and 80x128)
FAILED tests/pipelines/controlnet/test_controlnet_sdxl.py::ControlNetPipelineSDXLFastTests::test_inference_batch_single_identical - RuntimeError: mat1 and mat2 shapes cannot be multiplied (6x48 and 80x128)
FAILED tests/pipelines/controlnet/test_controlnet_sdxl.py::ControlNetPipelineSDXLFastTests::test_num_images_per_prompt - RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x56 and 80x128)
==== 3 failed, 1061 passed, 625 skipped, 1380 warnings in 752.68s (0:12:32) ====
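As context for the CI log above: PyTorch raises this RuntimeError whenever the inner dimensions of a matrix multiplication disagree, typically when a linear layer receives features of an unexpected width. A minimal reproduction of this class of error (not the actual failing pipeline code):

```python
import torch

# A projection expecting 80 input features, fed an input of width 56,
# reproduces the same message shape as the CI failure above.
proj = torch.nn.Linear(80, 128)
bad_input = torch.randn(4, 56)
try:
    proj(bad_input)
except RuntimeError as e:
    print(e)  # mat1 and mat2 shapes cannot be multiplied (4x56 and 80x128)
```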

@sayakpaul
Member

@gkorepanov let us know if there's anything we can do to accelerate this PR :-)

@gkorepanov
Author

@gkorepanov let us know if there's anything we can do to accelerate this PR :-)

Sorry for the long delay, my availability has been limited recently. I've managed to address the issues with the current tests. However, I'm uncertain about adding new tests, since we lack an official ControlNet model for SDXL to run a comprehensive test. So I've merely added tests to verify the pipeline's basic functionality with CFG disabled and guess mode enabled: https://github.com/huggingface/diffusers/pull/4155/files#diff-fa8f17e303ce826636a7e4038a33ea652b1f9279cce289ddaacb4332efd0b4f9R263

Please let me know if I should add new tests in some different way.

Regarding multi-controlnet, I'm unsure if I'll be able to dedicate time to work on it in the near future :(

@sayakpaul
Member

Will get to reviewing this, soon! Thank you for your hard work!

@sayakpaul
Member

@gkorepanov thanks for your efforts! The PR looks good to me.

However, I'm uncertain about adding new tests since we lack an official ControlNet model for SDXL to run a comprehensive test.

I created this dummy ControlNet pipeline: https://huggingface.co/hf-internal-testing/dummy-sdxl-controlnet-pipe. Will this suffice for testing?

@sayakpaul
Member

Hey @gkorepanov. I think the PR looks good in its current form. Let me also check with @yiyixuxu. Yiyi, could you also give this a look?

@sayakpaul sayakpaul requested a review from yiyixuxu August 14, 2023 03:46
Collaborator

@yiyixuxu yiyixuxu left a comment

Thanks for the PR!

assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4


class ControlNetPipelineSDXLGuessModeFastTests(ControlNetPipelineSDXLFastTests):
Collaborator

Can we test all four scenarios here?

  • guess_mode == True & CFG == True
  • guess_mode == False & CFG == True
  • guess_mode == False & CFG == False
  • guess_mode == True & CFG == False

I found the logic a little complex, so let's first add tests here to make sure it works as expected.
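A functional smoke test over the four combinations could be sketched like this, using the stdlib `unittest.subTest` (note `run_pipeline` is a hypothetical stand-in for invoking the real SDXL ControlNet pipeline with dummy components):

```python
import itertools
import unittest

class GuessModeCFGSmokeTests(unittest.TestCase):
    def run_pipeline(self, guess_mode, guidance_scale):
        # Hypothetical stand-in: the real test would build the SDXL ControlNet
        # pipeline from dummy components and call it with these kwargs.
        return {"guess_mode": guess_mode, "guidance_scale": guidance_scale}

    def test_all_guess_mode_cfg_combinations(self):
        # CFG is enabled in diffusers when guidance_scale > 1.0.
        for guess_mode, cfg in itertools.product([True, False], repeat=2):
            with self.subTest(guess_mode=guess_mode, cfg=cfg):
                out = self.run_pipeline(guess_mode, 7.5 if cfg else 1.0)
                self.assertIsNotNone(out)
```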

Author

@gkorepanov gkorepanov Aug 17, 2023

Hi! If you have any ideas on how we could simplify the logic, they would be welcome!

Regarding tests, do you think it is sufficient to run functional tests (i.e. run the pipeline with all combinations of parameters and make sure it does not fail), or do we need to run heavy tests with real checkpoint loading that check the outputs?

Member

Functional tests are fine here.

Author

@sayakpaul
Added tests for all combinations of params: 87b7cfe

I was not sure how to add multiple options given unittest's limited support for parametrisation, so feel free to suggest better options.

Member

I think we can make use of parameterized here, no? A couple of our test cases use it.

Let me know.

Author

Let me know.

I have already used it; please have a look at commit 87b7cfe.

Though parameterized has a bug related to class inheritance; I used a workaround, but I don't like it too much.

@sayakpaul
Member

@yiyixuxu @DN6 could you give this a look? I think this is important as it fixes support for guess mode in SDXL ControlNet.

@yiyixuxu
Collaborator

@gkorepanov
I made a PR here: #4799. I think it's doing the same thing, but let me know if it's not ;)

If possible, we want to keep the current implementation's logic for handling guess_mode; I explained my reasoning here: #4155 (comment)

@gkorepanov
Author

If it is possible, we want to keep the logic of how we handle the guess_mode in the current implementation

No problem, up to you. Closing in favour of #4799

@gkorepanov gkorepanov closed this Aug 28, 2023