From a7238746e24e4850666c8d4e2bcb1b3a02fb73f8 Mon Sep 17 00:00:00 2001 From: William Berman Date: Fri, 2 Jun 2023 18:33:26 -0700 Subject: [PATCH] small tweaks for parsing thibaudz controlnet checkpoints --- ...onvert_original_controlnet_to_diffusers.py | 18 ++++ .../stable_diffusion/convert_from_ckpt.py | 99 +++++++++++++------ 2 files changed, 87 insertions(+), 30 deletions(-) diff --git a/scripts/convert_original_controlnet_to_diffusers.py b/scripts/convert_original_controlnet_to_diffusers.py index a9e05abd4cf1..9466bd27234c 100644 --- a/scripts/convert_original_controlnet_to_diffusers.py +++ b/scripts/convert_original_controlnet_to_diffusers.py @@ -75,6 +75,22 @@ ) parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + + # small workaround to get argparser to parse a boolean input as either true _or_ false + def parse_bool(string): + if string == "True": + return True + elif string == "False": + return False + else: + raise ValueError(f"could not parse string as bool {string}") + + parser.add_argument( + "--use_linear_projection", help="Override for use linear projection", required=False, type=parse_bool + ) + + parser.add_argument("--cross_attention_dim", help="Override for cross attention_dim", required=False, type=int) + args = parser.parse_args() controlnet = download_controlnet_from_original_ckpt( @@ -86,6 +102,8 @@ upcast_attention=args.upcast_attention, from_safetensors=args.from_safetensors, device=args.device, + use_linear_projection=args.use_linear_projection, + cross_attention_dim=args.cross_attention_dim, ) controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 7ba1bbd996db..e59b91e486f5 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -339,41 +339,46 @@ def create_ldm_bert_config(original_config): return config -def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False): +def convert_ldm_unet_checkpoint( + checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False +): """ Takes a state dict and a config, and returns a converted checkpoint. """ - # extract state_dict for UNet - unet_state_dict = {} - keys = list(checkpoint.keys()) - - if controlnet: - unet_key = "control_model." + if skip_extract_state_dict: + unet_state_dict = checkpoint else: - unet_key = "model.diffusion_model." - - # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA - if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: - print(f"Checkpoint {path} has both EMA and non-EMA weights.") - print( - "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" - " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." - ) - for key in keys: - if key.startswith("model.diffusion_model"): - flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) - else: - if sum(k.startswith("model_ema") for k in keys) > 100: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") print( - "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" - " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) new_checkpoint = {} @@ -956,17 +961,42 @@ def stable_unclip_image_noising_components( def convert_controlnet_checkpoint( - checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=None, + cross_attention_dim=None, ): ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) ctrlnet_config["upcast_attention"] = upcast_attention ctrlnet_config.pop("sample_size") + if use_linear_projection is not None: + ctrlnet_config["use_linear_projection"] = use_linear_projection + + if cross_attention_dim is not None: + ctrlnet_config["cross_attention_dim"] = cross_attention_dim + controlnet_model = ControlNetModel(**ctrlnet_config) + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + skip_extract_state_dict = True + else: + skip_extract_state_dict = False + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True + checkpoint, + ctrlnet_config, + path=checkpoint_path, + extract_ema=extract_ema, + controlnet=True, + skip_extract_state_dict=skip_extract_state_dict, ) controlnet_model.load_state_dict(converted_ctrl_checkpoint) @@ -1344,6 +1374,8 @@ def download_controlnet_from_original_ckpt( upcast_attention: Optional[bool] = None, device: str = None, from_safetensors: bool = False, + use_linear_projection: Optional[bool] = None, + cross_attention_dim: Optional[bool] = None, ) -> DiffusionPipeline: if not is_omegaconf_available(): raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) @@ -1381,7 +1413,14 @@ def download_controlnet_from_original_ckpt( raise ValueError("`control_stage_config` not present in original config") controlnet_model = convert_controlnet_checkpoint( - checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=use_linear_projection, + cross_attention_dim=cross_attention_dim, ) return controlnet_model