From fc83f21705b4c851a0565fb152715fefd6096faa Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 4 Sep 2023 06:13:35 +0000 Subject: [PATCH 1/2] fix --- src/diffusers/pipelines/pipeline_utils.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index a041e6cd003c..110c97acdcdf 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1147,8 +1147,22 @@ def load_module(name, value): "variant": variant, "use_safetensors": use_safetensors, } + + def get_connected_passed_kwargs(prefix): + connected_passed_class_obj = { + k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix + } + connected_passed_pipe_kwargs = { + k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix + } + + connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs} + return connected_passed_kwargs + connected_pipes = { - prefix: DiffusionPipeline.from_pretrained(repo_id, **load_kwargs.copy()) + prefix: DiffusionPipeline.from_pretrained( + repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix) + ) for prefix, repo_id in connected_pipes.items() if repo_id is not None } From 064ded6a01bae3ce7efcce7c2b7d47e7ea098014 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 4 Sep 2023 08:34:57 +0000 Subject: [PATCH 2/2] add test --- tests/pipelines/test_pipelines_combined.py | 27 +++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_combined.py b/tests/pipelines/test_pipelines_combined.py index 925fa6e3c24d..c394ec0b1691 100644 --- a/tests/pipelines/test_pipelines_combined.py +++ b/tests/pipelines/test_pipelines_combined.py @@ -18,7 +18,13 @@ import torch from huggingface_hub import ModelCard -from diffusers import DiffusionPipeline, KandinskyV22CombinedPipeline, KandinskyV22Pipeline, KandinskyV22PriorPipeline +from diffusers import ( + DDPMScheduler, + DiffusionPipeline, + KandinskyV22CombinedPipeline, + KandinskyV22Pipeline, + KandinskyV22PriorPipeline, +) from diffusers.pipelines.pipeline_utils import CONNECTED_PIPES_KEYS @@ -101,3 +107,22 @@ def test_load_connected_checkpoint_default(self): assert dict(component.config) == dict(comp.config) else: assert component.__class__ == comp.__class__ + + def test_load_connected_checkpoint_with_passed_obj(self): + pipeline = KandinskyV22CombinedPipeline.from_pretrained( + "hf-internal-testing/tiny-random-kandinsky-v22-decoder" + ) + prior_scheduler = DDPMScheduler.from_config(pipeline.prior_scheduler.config) + scheduler = DDPMScheduler.from_config(pipeline.scheduler.config) + + # make sure we pass a different scheduler and prior_scheduler + assert pipeline.prior_scheduler.__class__ != prior_scheduler.__class__ + assert pipeline.scheduler.__class__ != scheduler.__class__ + + pipeline_new = KandinskyV22CombinedPipeline.from_pretrained( + "hf-internal-testing/tiny-random-kandinsky-v22-decoder", + prior_scheduler=prior_scheduler, + scheduler=scheduler, + ) + assert dict(pipeline_new.prior_scheduler.config) == dict(prior_scheduler.config) + assert dict(pipeline_new.scheduler.config) == dict(scheduler.config)