diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index a16213639526..dbc1b27e88be 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -45,6 +45,8 @@ PNDMScheduler, PriorTransformer, StableDiffusionControlNetPipeline, + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, StableDiffusionPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, @@ -979,6 +981,7 @@ def download_from_original_stable_diffusion_ckpt( image_size: int = 512, prediction_type: str = None, model_type: str = None, + is_img2img: bool = False, extract_ema: bool = False, scheduler_type: str = "pndm", num_in_channels: Optional[int] = None, @@ -1018,6 +1021,8 @@ def download_from_original_stable_diffusion_ckpt( model_type (`str`, *optional*, defaults to `None`): The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder", "PaintByExample"]`. + is_img2img (`bool`, *optional*, defaults to `False`): + Whether the model should be loaded as an img2img pipeline. extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for @@ -1193,16 +1198,44 @@ def download_from_original_stable_diffusion_ckpt( requires_safety_checker=False, ) else: - pipe = StableDiffusionPipeline( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - ) + if ( + hasattr(original_config, "model") + and hasattr(original_config.model, "target") + and "LatentInpaintDiffusion" in original_config.model.target + ): + pipe = StableDiffusionInpaintPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + else: + if is_img2img: + pipe = StableDiffusionImg2ImgPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + else: + pipe = StableDiffusionPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) else: image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( original_config, clip_stats_path=clip_stats_path, device=device @@ -1293,15 +1326,41 @@ def download_from_original_stable_diffusion_ckpt( feature_extractor=feature_extractor, ) else: - pipe = StableDiffusionPipeline( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) + if ( + hasattr(original_config, "model") + and hasattr(original_config.model, "target") + and "LatentInpaintDiffusion" in original_config.model.target + ): + pipe = StableDiffusionInpaintPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + else: + if is_img2img: + pipe = StableDiffusionImg2ImgPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + else: + pipe = StableDiffusionPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) else: text_config = create_ldm_bert_config(original_config) text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)