diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index c37ee87d8899..142d3a842199 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -33,7 +33,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import create_repo, model_info, upload_folder +from huggingface_hub import create_repo, upload_folder from huggingface_hub.utils import insecure_hashlib from packaging import version from PIL import Image @@ -756,16 +756,6 @@ def __getitem__(self, index): return example -def model_has_vae(args): - config_file_name = os.path.join("vae", AutoencoderKL.config_name) - if os.path.isdir(args.pretrained_model_name_or_path): - config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name) - return os.path.isfile(config_file_name) - else: - files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings - return any(file.rfilename == config_file_name for file in files_in_repo) - - def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): if tokenizer_max_length is not None: max_length = tokenizer_max_length @@ -920,11 +910,13 @@ def main(args): args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) - if model_has_vae(args): + try: vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant ) - else: + except OSError: + # IF does not have a VAE so let's just set it to None + # We don't have to error out here vae = None unet = UNet2DConditionModel.from_pretrained(