diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 89ce88455e20..561ae740738c 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -1389,7 +1389,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         extract_ema = kwargs.pop("extract_ema", False)
-        image_size = kwargs.pop("image_size", 512)
+        image_size = kwargs.pop("image_size", None)
         scheduler_type = kwargs.pop("scheduler_type", "pndm")
         num_in_channels = kwargs.pop("num_in_channels", None)
         upcast_attention = kwargs.pop("upcast_attention", None)
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 82f270955cc5..d4b119caaf93 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -24,6 +24,7 @@
     AutoFeatureExtractor,
     BertTokenizerFast,
     CLIPImageProcessor,
+    CLIPTextConfig,
     CLIPTextModel,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
@@ -48,7 +49,7 @@
     PNDMScheduler,
     UnCLIPScheduler,
 )
-from ...utils import is_omegaconf_available, is_safetensors_available, logging
+from ...utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available, logging
 from ...utils.import_utils import BACKENDS_MAPPING
 from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
 from ..paint_by_example import PaintByExampleImageEncoder
@@ -57,6 +58,10 @@
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
 
 
+if is_accelerate_available():
+    from accelerate import init_empty_weights
+    from accelerate.utils import set_module_tensor_to_device
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -770,11 +775,12 @@ def _copy_layers(hf_layers, pt_layers):
 
 
 def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
-    text_model = (
-        CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
-        if text_encoder is None
-        else text_encoder
-    )
+    if text_encoder is None:
+        config_name = "openai/clip-vit-large-patch14"
+        config = CLIPTextConfig.from_pretrained(config_name)
+
+        with init_empty_weights():
+            text_model = CLIPTextModel(config)
 
     keys = list(checkpoint.keys())
 
@@ -787,7 +793,8 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
             if key.startswith(prefix):
                 text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]
 
-    text_model.load_state_dict(text_model_dict)
+    for param_name, param in text_model_dict.items():
+        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
 
     return text_model
 
@@ -884,14 +891,26 @@ def convert_paint_by_example_checkpoint(checkpoint):
     return model
 
 
-def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
+def convert_open_clip_checkpoint(
+    checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs
+):
     # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
-    text_model = CLIPTextModelWithProjection.from_pretrained(
-        "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
-    )
+    # text_model = CLIPTextModelWithProjection.from_pretrained(
+    #     "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
+    # )
+    config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs)
+
+    with init_empty_weights():
+        text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
 
     keys = list(checkpoint.keys())
 
+    keys_to_ignore = []
+    if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23:
+        # make sure to remove all keys > 22
+        keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")]
+        keys_to_ignore += ["cond_stage_model.model.text_projection"]
+
     text_model_dict = {}
 
     if prefix + "text_projection" in checkpoint:
@@ -902,8 +921,8 @@ def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
     text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
 
     for key in keys:
-        # if "resblocks.23" in key:  # Diffusers drops the final layer and only uses the penultimate layer
-        #     continue
+        if key in keys_to_ignore:
+            continue
         if key[len(prefix) :] in textenc_conversion_map:
             if key.endswith("text_projection"):
                 value = checkpoint[key].T
@@ -931,7 +950,8 @@ def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
 
             text_model_dict[new_key] = checkpoint[key]
 
-    text_model.load_state_dict(text_model_dict)
+    for param_name, param in text_model_dict.items():
+        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
 
     return text_model
 
@@ -1061,7 +1081,7 @@ def convert_controlnet_checkpoint(
 def download_from_original_stable_diffusion_ckpt(
     checkpoint_path: str,
     original_config_file: str = None,
-    image_size: int = 512,
+    image_size: Optional[int] = None,
     prediction_type: str = None,
     model_type: str = None,
     extract_ema: bool = False,
@@ -1144,6 +1164,7 @@ def download_from_original_stable_diffusion_ckpt(
         LDMTextToImagePipeline,
         PaintByExamplePipeline,
         StableDiffusionControlNetPipeline,
+        StableDiffusionInpaintPipeline,
         StableDiffusionPipeline,
         StableDiffusionXLImg2ImgPipeline,
         StableDiffusionXLPipeline,
@@ -1166,12 +1187,9 @@ def download_from_original_stable_diffusion_ckpt(
         if not is_safetensors_available():
            raise ValueError(BACKENDS_MAPPING["safetensors"][1])
 
-        from safetensors import safe_open
+        from safetensors.torch import load_file as safe_load
 
-        checkpoint = {}
-        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
-            for key in f.keys():
-                checkpoint[key] = f.get_tensor(key)
+        checkpoint = safe_load(checkpoint_path, device="cpu")
     else:
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -1183,7 +1201,7 @@ def download_from_original_stable_diffusion_ckpt(
     if "global_step" in checkpoint:
         global_step = checkpoint["global_step"]
     else:
-        logger.warning("global_step key not found in model")
+        logger.debug("global_step key not found in model")
         global_step = None
 
     # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
@@ -1230,9 +1248,15 @@ def download_from_original_stable_diffusion_ckpt(
             model_type = "SDXL"
         else:
             model_type = "SDXL-Refiner"
+        if image_size is None:
+            image_size = 1024
 
-    if num_in_channels is not None:
-        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
+    if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
+        num_in_channels = 9
+    elif num_in_channels is None:
+        num_in_channels = 4
+
+    original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
 
     if (
         "parameterization" in original_config["model"]["params"]
@@ -1263,7 +1287,6 @@
     num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
 
     if model_type in ["SDXL", "SDXL-Refiner"]:
-        image_size = 1024
         scheduler_dict = {
             "beta_schedule": "scaled_linear",
             "beta_start": 0.00085,
@@ -1279,7 +1302,6 @@
         }
         scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
         scheduler_type = "euler"
-        vae_path = "stabilityai/sdxl-vae"
     else:
         beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
         beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
@@ -1318,25 +1340,45 @@ def download_from_original_stable_diffusion_ckpt(
     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["upcast_attention"] = upcast_attention
-    unet = UNet2DConditionModel(**unet_config)
+    with init_empty_weights():
+        unet = UNet2DConditionModel(**unet_config)
 
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
     )
-    unet.load_state_dict(converted_unet_checkpoint)
+
+    for param_name, param in converted_unet_checkpoint.items():
+        set_module_tensor_to_device(unet, param_name, "cpu", value=param)
 
     # Convert the VAE model.
     if vae_path is None:
         vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
         converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
-        vae = AutoencoderKL(**vae_config)
-        vae.load_state_dict(converted_vae_checkpoint)
+
+        if (
+            "model" in original_config
+            and "params" in original_config.model
+            and "scale_factor" in original_config.model.params
+        ):
+            vae_scaling_factor = original_config.model.params.scale_factor
+        else:
+            vae_scaling_factor = 0.18215  # default SD scaling factor
+
+        vae_config["scaling_factor"] = vae_scaling_factor
+
+        with init_empty_weights():
+            vae = AutoencoderKL(**vae_config)
+
+        for param_name, param in converted_vae_checkpoint.items():
+            set_module_tensor_to_device(vae, param_name, "cpu", value=param)
     else:
         vae = AutoencoderKL.from_pretrained(vae_path)
 
     if model_type == "FrozenOpenCLIPEmbedder":
-        text_model = convert_open_clip_checkpoint(checkpoint)
+        config_name = "stabilityai/stable-diffusion-2"
+        config_kwargs = {"subfolder": "text_encoder"}
+
+        text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
         tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
 
         if stable_unclip is None:
@@ -1469,7 +1511,12 @@ def download_from_original_stable_diffusion_ckpt(
             tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
             text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
             tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
-            text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.1.model.")
+
+            config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+            config_kwargs = {"projection_dim": 1280}
+            text_encoder_2 = convert_open_clip_checkpoint(
+                checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
+            )
 
             pipe = StableDiffusionXLPipeline(
                 vae=vae,
@@ -1485,7 +1532,12 @@ def download_from_original_stable_diffusion_ckpt(
             tokenizer = None
             text_encoder = None
             tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
-            text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.0.model.")
+
+            config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+            config_kwargs = {"projection_dim": 1280}
+            text_encoder_2 = convert_open_clip_checkpoint(
+                checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs
+            )
 
             pipe = StableDiffusionXLImg2ImgPipeline(
                 vae=vae,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index d958f0e3fb72..d719fd141983 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -24,7 +24,7 @@
 
 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -153,7 +153,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
     return mask, masked_image
 
 
-class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionInpaintPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion.
 
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
index e7b084acb280..a3b331d17a51 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -20,17 +20,20 @@
 
 import numpy as np
 import torch
+from huggingface_hub import hf_hub_download
 from PIL import Image
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from diffusers import (
     AutoencoderKL,
+    DDIMScheduler,
     DPMSolverMultistepScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
     StableDiffusionInpaintPipeline,
     UNet2DConditionModel,
 )
+from diffusers.models.attention_processor import AttnProcessor
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
 from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import (
@@ -512,6 +515,42 @@ def test_stable_diffusion_simple_inpaint_ddim(self):
 
         assert np.abs(expected_slice - image_slice).max() < 6e-4
 
+    def test_download_local(self):
+        filename = hf_hub_download("runwayml/stable-diffusion-inpainting", filename="sd-v1-5-inpainting.ckpt")
+
+        pipe = StableDiffusionInpaintPipeline.from_single_file(filename, torch_dtype=torch.float16)
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.to("cuda")
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 1
+        image_out = pipe(**inputs).images[0]
+
+        assert image_out.shape == (512, 512, 3)
+
+    def test_download_ckpt_diff_format_is_same(self):
+        ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
+
+        pipe = StableDiffusionInpaintPipeline.from_single_file(ckpt_path)
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.to("cuda")
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 5
+        image_ckpt = pipe(**inputs).images[0]
+
+        pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.to("cuda")
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 5
+        image = pipe(**inputs).images[0]
+
+        assert np.max(np.abs(image - image_ckpt)) < 1e-4
+
 
 @nightly
 @require_torch_gpu
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
index 67486e61dbef..1db2e18e5b19 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
@@ -19,6 +19,7 @@
 
 import numpy as np
 import torch
+from huggingface_hub import hf_hub_download
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from diffusers import (
@@ -29,6 +30,7 @@
     StableDiffusionPipeline,
     UNet2DConditionModel,
 )
+from diffusers.models.attention_processor import AttnProcessor
 from diffusers.utils import load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
 
@@ -426,6 +428,40 @@ def test_stable_diffusion_text2img_pipeline_v_pred_fp16(self):
         assert image.shape == (768, 768, 3)
         assert np.abs(expected_image - image).max() < 7.5e-1
 
+    def test_download_local(self):
+        filename = hf_hub_download("stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.safetensors")
+
+        pipe = StableDiffusionPipeline.from_single_file(filename, torch_dtype=torch.float16)
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.to("cuda")
+
+        image_out = pipe("test", num_inference_steps=1, output_type="np").images[0]
+
+        assert image_out.shape == (768, 768, 3)
+
+    def test_download_ckpt_diff_format_is_same(self):
+        single_file_path = (
+            "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
+        )
+
+        pipe_single = StableDiffusionPipeline.from_single_file(single_file_path)
+        pipe_single.scheduler = DDIMScheduler.from_config(pipe_single.scheduler.config)
+        pipe_single.unet.set_attn_processor(AttnProcessor())
+        pipe_single.to("cuda")
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        image_ckpt = pipe_single("a turtle", num_inference_steps=5, generator=generator, output_type="np").images[0]
+
+        pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.to("cuda")
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        image = pipe("a turtle", num_inference_steps=5, generator=generator, output_type="np").images[0]
+
+        assert np.max(np.abs(image - image_ckpt)) < 1e-3
+
     def test_stable_diffusion_text2img_intermediate_state_v_pred(self):
         number_of_steps = 0
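Note on the loading pattern adopted throughout convert_from_ckpt.py above: instead of instantiating each model with real weights and then calling load_state_dict, the diff builds the module on the meta device with accelerate's init_empty_weights and materializes tensors one at a time with set_module_tensor_to_device. A minimal standalone sketch of that pattern follows; the toy torch.nn.Linear and hand-built state dict are illustrative only and not part of the diff.

```python
# Minimal sketch of the accelerate-based loading pattern used in this diff.
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Parameters are created on the "meta" device: no memory is allocated and no
# random initialization runs, which is what makes single-file loading cheaper.
with init_empty_weights():
    model = torch.nn.Linear(4, 4)

state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

# load_state_dict() cannot copy into meta tensors, so each tensor is assigned
# explicitly, mirroring the loops added in convert_from_ckpt.py.
for name, tensor in state_dict.items():
    set_module_tensor_to_device(model, name, "cpu", value=tensor)

print(model(torch.ones(4)))  # the module is now fully materialized on CPU
```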
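For reference, a usage sketch of the single-file path that the new inpainting tests exercise. The checkpoint URL and scheduler swap mirror test_download_ckpt_diff_format_is_same; the prompt and the blank placeholder image/mask are illustrative, and a CUDA device is assumed.

```python
# Sketch of StableDiffusionInpaintPipeline.from_single_file, enabled by this diff.
import torch
from PIL import Image
from diffusers import DDIMScheduler, StableDiffusionInpaintPipeline

ckpt_url = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"

# num_in_channels now defaults to 9 for the inpainting pipeline class,
# so no extra kwargs are needed for this checkpoint.
pipe = StableDiffusionInpaintPipeline.from_single_file(ckpt_url, torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

init_image = Image.new("RGB", (512, 512), color=(127, 127, 127))  # placeholder input image
mask_image = Image.new("L", (512, 512), color=255)  # placeholder mask (inpaint everything)

image = pipe(
    prompt="a red couch in a living room",
    image=init_image,
    mask_image=mask_image,
    num_inference_steps=25,
).images[0]
image.save("inpaint_from_single_file.png")
```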