From c53fc3f77be8b2cb9c8bd17d967bba71a0b7e5a4 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Thu, 16 Mar 2023 12:43:44 +0530 Subject: [PATCH 1/6] [ckpt loader] Allow loading the Inpaint and Img2Img pipelines, while loading a ckpt model --- .../stable_diffusion/convert_from_ckpt.py | 89 +++++++++++++++---- 1 file changed, 70 insertions(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index ef4598433f82..4460d757c333 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -46,6 +46,8 @@ PriorTransformer, StableDiffusionControlNetPipeline, StableDiffusionPipeline, + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, UnCLIPScheduler, @@ -979,6 +981,7 @@ def download_from_original_stable_diffusion_ckpt( image_size: int = 512, prediction_type: str = None, model_type: str = None, + is_img2img: bool = False, extract_ema: bool = False, scheduler_type: str = "pndm", num_in_channels: Optional[int] = None, @@ -1017,6 +1020,8 @@ def download_from_original_stable_diffusion_ckpt( model_type (`str`, *optional*, defaults to `None`): The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder", "PaintByExample"]`. + is_img2img (`bool`, *optional*, defaults to `False`): + Whether the model should be loaded as an img2img pipeline. extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for @@ -1190,16 +1195,40 @@ def download_from_original_stable_diffusion_ckpt( requires_safety_checker=False, ) else: - pipe = StableDiffusionPipeline( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - ) + if "LatentInpaintDiffusion" in original_config.model.target: + pipe = StableDiffusionInpaintPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + else: + if is_img2img: + pipe = StableDiffusionImg2ImgPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + else: + pipe = StableDiffusionPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) else: image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( original_config, clip_stats_path=clip_stats_path, device=device @@ -1285,15 +1314,37 @@ def download_from_original_stable_diffusion_ckpt( feature_extractor=feature_extractor, ) else: - pipe = StableDiffusionPipeline( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) + if "LatentInpaintDiffusion" in original_config.model.target: + pipe = StableDiffusionInpaintPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + else: + if is_img2img: + pipe = StableDiffusionImg2ImgPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + else: + pipe = StableDiffusionPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) else: text_config = create_ldm_bert_config(original_config) text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) From d3d6c324a0617d339cb5a63ea0deab47e4f5d8d9 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Thu, 23 Mar 2023 19:28:59 +0530 Subject: [PATCH 2/6] Address review comment from PR --- src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 4460d757c333..8a2c2f35ff28 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -1195,7 +1195,7 @@ def download_from_original_stable_diffusion_ckpt( requires_safety_checker=False, ) else: - if "LatentInpaintDiffusion" in original_config.model.target: + if hasattr(original_config, "model") and hasattr(original_config.model, "target") and "LatentInpaintDiffusion" in original_config.model.target: pipe = StableDiffusionInpaintPipeline( vae=vae, text_encoder=text_model, @@ -1314,7 +1314,7 @@ def download_from_original_stable_diffusion_ckpt( feature_extractor=feature_extractor, ) else: - if "LatentInpaintDiffusion" in original_config.model.target: + if hasattr(original_config, "model") and hasattr(original_config.model, "target") and "LatentInpaintDiffusion" in original_config.model.target: pipe = StableDiffusionInpaintPipeline( vae=vae, text_encoder=text_model, From 11accc797a2e09991a2821223fd49b5dfdaccc38 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Thu, 30 Mar 2023 09:58:42 +0530 Subject: [PATCH 3/6] PyLint formatting --- .../pipelines/stable_diffusion/convert_from_ckpt.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 8a2c2f35ff28..e6d4dce169c4 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -1195,7 +1195,11 @@ def download_from_original_stable_diffusion_ckpt( requires_safety_checker=False, ) else: - if hasattr(original_config, "model") and hasattr(original_config.model, "target") and "LatentInpaintDiffusion" in original_config.model.target: + if ( + hasattr(original_config, "model") + and hasattr(original_config.model, "target") + and "LatentInpaintDiffusion" in original_config.model.target + ): pipe = StableDiffusionInpaintPipeline( vae=vae, text_encoder=text_model, @@ -1314,7 +1318,11 @@ def download_from_original_stable_diffusion_ckpt( feature_extractor=feature_extractor, ) else: - if hasattr(original_config, "model") and hasattr(original_config.model, "target") and "LatentInpaintDiffusion" in original_config.model.target: + if ( + hasattr(original_config, "model") + and hasattr(original_config.model, "target") + and "LatentInpaintDiffusion" in original_config.model.target + ): pipe = StableDiffusionInpaintPipeline( vae=vae, text_encoder=text_model, From 6ab1317c6d7c853d028c9b24e8263c14534b35e5 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Thu, 30 Mar 2023 10:01:32 +0530 Subject: [PATCH 4/6] Some more pylint fixes, unrelated to our change --- src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index e6d4dce169c4..447cd828a38f 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -1075,7 +1075,9 @@ def download_from_original_stable_diffusion_ckpt( key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" # model_type = "v1" - config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + config_url = ( + "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + ) if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024: # model_type = "v2" From 970240050864122878191b7258e5671abed80906 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Thu, 30 Mar 2023 10:10:47 +0530 Subject: [PATCH 5/6] Another pylint fix --- src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 447cd828a38f..e6d4dce169c4 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -1075,9 +1075,7 @@ def download_from_original_stable_diffusion_ckpt( key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" # model_type = "v1" - config_url = ( - "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" - ) + config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024: # model_type = "v2" From 464831b7c773fa66cd0c0c880684817084b22735 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Thu, 30 Mar 2023 10:20:51 +0530 Subject: [PATCH 6/6] Styling fix --- src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 326c8a02f910..dbc1b27e88be 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -45,9 +45,9 @@ PNDMScheduler, PriorTransformer, StableDiffusionControlNetPipeline, - StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, + StableDiffusionPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, UnCLIPScheduler,