From aee50969ace7b327fb8c33f68ad917f992a23cb0 Mon Sep 17 00:00:00 2001
From: subpanic
Date: Sat, 31 Dec 2022 15:54:39 +0000
Subject: [PATCH 1/6] add preview image support alongside checkpointing for dreambooth.

---
 examples/dreambooth/train_dreambooth.py | 46 +++++++++++++++++++++++--
 1 file changed, 44 insertions(+), 2 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index ddc3a608767b..65e85d5c6061 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -13,6 +13,7 @@
 from torch.utils.data import Dataset
 
 from accelerate import Accelerator
+import logging
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
@@ -30,6 +31,11 @@
 check_min_version("0.10.0.dev0")
 
 logger = get_logger(__name__)
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=logging.INFO,
+)
 
 
 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
@@ -261,6 +267,10 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument("--save_samples", type=int, default=0, help="Whether or not to save samples for every checkpoint specified by --checkpointing_steps.")
+    parser.add_argument("--sample_steps", type=int, default=40, help="Number of steps for generating samples.")
+    parser.add_argument("--sample_prompt", type=str, default=None, help="Prompt to use for sample generation.")
+    parser.add_argument("--sample_seed", type=int, default=0, help="Seed for the sample generation.")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -630,6 +640,8 @@ def main(args):
         weight_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
+
+    print(f" TYPE: {weight_dtype}")
 
     # Move text_encoder and vae to gpu.
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
@@ -716,7 +728,8 @@ def main(args):
                 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                 # Predict the noise residual
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                with torch.cuda.amp.autocast(enabled=True):
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
@@ -765,13 +778,42 @@ def main(args):
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
 
+                        if args.save_samples > 0:
+                            if torch.cuda.is_available():
+                                torch.cuda.empty_cache()
+
+                            logger.info(f"Generating {args.save_samples} samples at step {global_step} with prompt: {args.sample_prompt}")
+                            pipeline = DiffusionPipeline.from_pretrained(
+                                args.pretrained_model_name_or_path,
+                                unet=accelerator.unwrap_model(unet),
+                                text_encoder=accelerator.unwrap_model(text_encoder),
+                                revision=args.revision,
+                                torch_dtype=weight_dtype,
+                            ).to(accelerator.device)
+
+                            # Generate samples
+                            # with torch.cuda.amp.autocast():
+                            with torch.cuda.amp.autocast(enabled=True):
+                                images = pipeline(prompt=args.sample_prompt, num_images_per_prompt=args.save_samples, num_inference_steps=args.sample_steps, width=args.resolution, height=args.resolution).images
+
+                            for i, image in enumerate(images):
+                                hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+                                os.makedirs(os.path.join(args.output_dir, "samples"), exist_ok=True)
+                                image_filename = os.path.join(args.output_dir, "samples", f"step-{global_step}_loss-{loss.detach().item()}_prompt-{args.sample_prompt}_hash-{hash_image}.png")
+                                image.save(image_filename)
+
+                            del pipeline
+                            # if torch.cuda.is_available():
+                            #     torch.cuda.empty_cache()
+
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
 
             if global_step >= args.max_train_steps:
                 break
-
+
     accelerator.wait_for_everyone()
 
     # Create the pipeline using the trained modules and save it.
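Patch 1 generates its preview images without a fixed seed, so the samples are not directly comparable from checkpoint to checkpoint; the next patch pins a torch.Generator for exactly this reason. The underlying diffusers pattern, as a minimal sketch (the model path and prompt are placeholders, not values from this series):

import torch
from diffusers import DiffusionPipeline

# Placeholder model path; any diffusers-format Stable Diffusion model works here.
pipe = DiffusionPipeline.from_pretrained("/path/to/model", torch_dtype=torch.float16).to("cuda")

# Reusing the same seed at every checkpoint makes previews comparable across training.
generator = torch.Generator(device="cuda").manual_seed(980273)
image = pipe("a test prompt", generator=generator, num_inference_steps=40).images[0]
image.save("preview.png")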
From 48dc9196f5e6e2885a38183310c52de3d84fa997 Mon Sep 17 00:00:00 2001
From: subpanic
Date: Mon, 2 Jan 2023 18:23:50 +0000
Subject: [PATCH 2/6] Adding support for sample image outputs during training for each checkpoint interval.

---
 examples/dreambooth/train_dreambooth.py | 36 +++++++++++++++++--------
 1 file changed, 25 insertions(+), 11 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 39f6d5c5e70a..996a8684a540 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -267,10 +267,10 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
-    parser.add_argument("--save_samples", type=int, default=0, help="Whether or not to save samples for every checkpoint specified by --checkpointing_steps.")
-    parser.add_argument("--sample_steps", type=int, default=40, help="Number of steps for generating samples.")
-    parser.add_argument("--sample_prompt", type=str, default=None, help="Prompt to use for sample generation.")
-    parser.add_argument("--sample_seed", type=int, default=0, help="Seed for the sample generation.")
+    parser.add_argument("--samples_per_checkpoint", type=int, default=0, help="Number of sample images to save at every checkpoint specified by --checkpointing_steps. 0 disables sampling.")
+    parser.add_argument("--sample_steps", type=int, default=40, help="Number of steps for generating sample images.")
+    parser.add_argument("--sample_prompt", type=str, default=None, help="Prompt to use for sample image generation.")
+    parser.add_argument("--sample_seed", type=int, default=-1, help="Seed for the per-checkpoint sample image generation. -1 selects a random seed.")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -772,39 +772,53 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1
 
+                # Save checkpoint
                 if global_step % args.checkpointing_steps == 0:
                     if accelerator.is_main_process:
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
 
+                        # Also generate and save sample images if specified
                         if args.save_samples > 0:
+                            # Make sure any data leftover from previous interim pipeline is cleared
                             if torch.cuda.is_available():
                                 torch.cuda.empty_cache()
 
                             logger.info(f"Generating {args.save_samples} samples at step {global_step} with prompt: {args.sample_prompt}")
+
+                            # Load current training state into a new diffusion pipeline to generate samples
                             pipeline = DiffusionPipeline.from_pretrained(
                                 args.pretrained_model_name_or_path,
                                 unet=accelerator.unwrap_model(unet),
                                 text_encoder=accelerator.unwrap_model(text_encoder),
-                                revision=args.revision,
-                                torch_dtype=weight_dtype,
+                                torch_dtype=weight_dtype
                             ).to(accelerator.device)
 
+                            # Allow a statically set or random seed for generated samples
+                            sampleSeed = args.sample_seed if args.sample_seed != -1 else torch.Generator(accelerator.device).seed()
+                            sampleGenerator = torch.Generator(device=accelerator.device)
+                            sampleGenerator.manual_seed(sampleSeed)
+
                             # Generate samples
-                            # with torch.cuda.amp.autocast():
                             with torch.cuda.amp.autocast(enabled=True):
-                                images = pipeline(prompt=args.sample_prompt, num_images_per_prompt=args.save_samples, num_inference_steps=args.sample_steps, width=args.resolution, height=args.resolution).images
-
+                                images = pipeline(
+                                    prompt=args.sample_prompt,
+                                    num_images_per_prompt=args.save_samples,
+                                    num_inference_steps=args.sample_steps,
+                                    generator=sampleGenerator,
+                                    width=args.resolution,
+                                    height=args.resolution).images
+
+                            # Save samples to 'samples' folder in output directory
                             for i, image in enumerate(images):
                                 hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                                 os.makedirs(os.path.join(args.output_dir, "samples"), exist_ok=True)
                                 image_filename = os.path.join(args.output_dir, "samples", f"step-{global_step}_loss-{loss.detach().item()}_prompt-{args.sample_prompt}_hash-{hash_image}.png")
                                 image.save(image_filename)
 
+                            # Remove interim pipeline reference
                             del pipeline
-                            # if torch.cuda.is_available():
-                            #     torch.cuda.empty_cache()
 
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
From 2d048373aac7bd78298631cd1d08d1de7435d10c Mon Sep 17 00:00:00 2001
From: subpanic
Date: Mon, 16 Jan 2023 22:13:53 +0000
Subject: [PATCH 3/6] Updating command args for sample generation

---
 examples/dreambooth/train_dreambooth.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 996a8684a540..e6a35c49b98a 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -57,6 +57,12 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
     else:
         raise ValueError(f"{model_class} is not supported.")
 
+def reportVramUsage(note, device_id):
+    t = torch.cuda.get_device_properties(device_id).total_memory / 1073741824  # bytes -> GiB (1024**3)
+    r = torch.cuda.memory_reserved(device_id) / 1073741824
+    a = torch.cuda.memory_allocated(device_id) / 1073741824
+    f = r - a  # free inside reserved
+    print(f"({note}) Device {device_id} === Total: {t:.2f} GiB, Reserved: {r:.2f} GiB, Allocated: {a:.2f} GiB, Free: {f:.2f} GiB")
 
 def parse_args(input_args=None):
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
@@ -780,12 +786,12 @@ def main(args):
                         logger.info(f"Saved state to {save_path}")
 
                         # Also generate and save sample images if specified
-                        if args.save_samples > 0:
+                        if args.samples_per_checkpoint > 0:
                            # Make sure any data leftover from previous interim pipeline is cleared
                             if torch.cuda.is_available():
                                 torch.cuda.empty_cache()
 
-                            logger.info(f"Generating {args.save_samples} samples at step {global_step} with prompt: {args.sample_prompt}")
+                            logger.info(f"Generating {args.samples_per_checkpoint} samples at step {global_step} with prompt: {args.sample_prompt}")
 
                             # Load current training state into a new diffusion pipeline to generate samples
                             pipeline = DiffusionPipeline.from_pretrained(
@@ -804,7 +810,7 @@ def main(args):
                             with torch.cuda.amp.autocast(enabled=True):
                                 images = pipeline(
                                     prompt=args.sample_prompt,
-                                    num_images_per_prompt=args.save_samples,
+                                    num_images_per_prompt=args.samples_per_checkpoint,
                                     num_inference_steps=args.sample_steps,
                                     generator=sampleGenerator,
                                     width=args.resolution,
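The reportVramUsage helper added above is defined but never called by this patch; call sites are left to the user. A usage sketch (the note labels are arbitrary, and device 0 is an assumption):

reportVramUsage("before sampling", 0)  # prints total/reserved/allocated/free VRAM for CUDA device 0 in GiB
# ... per-checkpoint sample generation runs here ...
reportVramUsage("after sampling", 0)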
From 8540e2bca2e2195d0e95040f63b9218d0ea9dd47 Mon Sep 17 00:00:00 2001
From: subpanic
Date: Mon, 16 Jan 2023 22:16:06 +0000
Subject: [PATCH 4/6] adding dreambooth helper scripts

---
 examples/dreambooth/convert_checkpoint.py | 42 +++++++++++++++++++++++
 examples/dreambooth/dreambooth.sh         | 32 +++++++++++++++++
 2 files changed, 74 insertions(+)
 create mode 100644 examples/dreambooth/convert_checkpoint.py
 create mode 100755 examples/dreambooth/dreambooth.sh

diff --git a/examples/dreambooth/convert_checkpoint.py b/examples/dreambooth/convert_checkpoint.py
new file mode 100644
index 000000000000..5e08e98564ac
--- /dev/null
+++ b/examples/dreambooth/convert_checkpoint.py
@@ -0,0 +1,42 @@
+from accelerate import Accelerator
+from diffusers import DiffusionPipeline
+import argparse
+import os
+import sys
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--checkpoint_path", type=str, required=True)
+parser.add_argument("--model_path", type=str, required=True)
+parser.add_argument("--output_path", type=str, required=True)
+args = parser.parse_args()
+
+
+# if os.path.exists(args.output_path):
+#     raise ValueError(f"Output path {args.output_path} already exists.")
+# sys.exit(1)
+
+os.makedirs(args.output_path, exist_ok=True)
+
+
+# Load the pipeline with the same arguments (model, revision) that were used for training
+model_id = args.model_path
+pipeline = DiffusionPipeline.from_pretrained(model_id)
+
+accelerator = Accelerator()
+
+# Use text_encoder if `--train_text_encoder` was used for the initial training
+unet, text_encoder = accelerator.prepare(pipeline.unet, pipeline.text_encoder)
+
+# Restore state from a checkpoint path. You have to use an absolute path here.
+accelerator.load_state(args.checkpoint_path)
+
+# Rebuild the pipeline with the unwrapped models (assignment to .unet and .text_encoder should work too)
+pipeline = DiffusionPipeline.from_pretrained(
+    model_id,
+    unet=accelerator.unwrap_model(unet),
+    text_encoder=accelerator.unwrap_model(text_encoder),
+)
+
+# Perform inference, or save, or push to the hub
+pipeline.save_pretrained(args.output_path)
\ No newline at end of file
diff --git a/examples/dreambooth/dreambooth.sh b/examples/dreambooth/dreambooth.sh
new file mode 100755
index 000000000000..3569a8d57349
--- /dev/null
+++ b/examples/dreambooth/dreambooth.sh
@@ -0,0 +1,32 @@
+#!/bin/sh
+
+MODEL_NAME=/store/sd/diffusers_models/stable-diffusion-2-1
+DATA_DIR=/store/sd/training/drbolick_768/filter_1
+OUT_DIR=/store/sd/training/out/test
+ACCELERATE_LOG_LEVEL="INFO"
+LOG_LEVEL="INFO"
+PYTHONUNBUFFERED=1
+
+python train_dreambooth.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --instance_data_dir=$DATA_DIR \
+  --output_dir=$OUT_DIR \
+  --instance_prompt="drbolick style" \
+  --resolution=768 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=1 \
+  --learning_rate=1e-5 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --max_train_steps=5000 \
+  --checkpointing_steps=10 \
+  --use_8bit_adam \
+  --enable_xformers_memory_efficient_attention \
+  --samples_per_checkpoint=4 \
+  --sample_steps=40 \
+  --sample_prompt="a woman sitting among trees against the horizon. drbolick style" \
+  --sample_seed=980273
+
+# --gradient_checkpointing
+# --class_prompt="art style" \
+# --with_prior_preservation
\ No newline at end of file
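convert_checkpoint.py rebuilds a full diffusers pipeline from an accelerate training checkpoint and writes it to --output_path. A quick sanity check is to load the converted model and run one inference pass; a minimal sketch, with hypothetical paths:

from diffusers import DiffusionPipeline

# Hypothetical path: wherever convert_checkpoint.py wrote its --output_path
pipeline = DiffusionPipeline.from_pretrained("/path/to/converted_model").to("cuda")
image = pipeline("drbolick style", num_inference_steps=40).images[0]
image.save("converted_smoke_test.png")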
From 519aa0ad617c8779af471ecffab47e7431c260d9 Mon Sep 17 00:00:00 2001
From: subpanic
Date: Wed, 18 Jan 2023 16:04:32 +0000
Subject: [PATCH 5/6] Cleanup & adding support for multires training and device selection

---
 examples/dreambooth/dreambooth.sh       | 32 ---------
 examples/dreambooth/train_dreambooth.py | 96 ++++++++++++++++---------
 2 files changed, 61 insertions(+), 67 deletions(-)
 delete mode 100755 examples/dreambooth/dreambooth.sh

diff --git a/examples/dreambooth/dreambooth.sh b/examples/dreambooth/dreambooth.sh
deleted file mode 100755
index 3569a8d57349..000000000000
--- a/examples/dreambooth/dreambooth.sh
+++ /dev/null
@@ -1,32 +0,0 @@
-#!/bin/sh
-
-MODEL_NAME=/store/sd/diffusers_models/stable-diffusion-2-1
-DATA_DIR=/store/sd/training/drbolick_768/filter_1
-OUT_DIR=/store/sd/training/out/test
-ACCELERATE_LOG_LEVEL="INFO"
-LOG_LEVEL="INFO"
-PYTHONUNBUFFERED=1
-
-python train_dreambooth.py \
-  --pretrained_model_name_or_path=$MODEL_NAME \
-  --instance_data_dir=$DATA_DIR \
-  --output_dir=$OUT_DIR \
-  --instance_prompt="drbolick style" \
-  --resolution=768 \
-  --train_batch_size=1 \
-  --gradient_accumulation_steps=1 \
-  --learning_rate=1e-5 \
-  --lr_scheduler="constant" \
-  --lr_warmup_steps=0 \
-  --max_train_steps=5000 \
-  --checkpointing_steps=10 \
-  --use_8bit_adam \
-  --enable_xformers_memory_efficient_attention \
-  --samples_per_checkpoint=4 \
-  --sample_steps=40 \
-  --sample_prompt="a woman sitting among trees against the horizon. drbolick style" \
-  --sample_seed=980273
-
-# --gradient_checkpointing
-# --class_prompt="art style" \
-# --with_prior_preservation
\ No newline at end of file
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 930b2922cb66..0492da6da096 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -17,7 +17,6 @@
 import diffusers
 import transformers
 from accelerate import Accelerator
-import logging
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
@@ -35,12 +34,6 @@
 check_min_version("0.10.0.dev0")
 
 logger = get_logger(__name__)
-logging.basicConfig(
-    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-    datefmt="%m/%d/%Y %H:%M:%S",
-    level=logging.INFO,
-)
-
 
 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
     text_encoder_config = PretrainedConfig.from_pretrained(
@@ -305,6 +298,10 @@ def parse_args(input_args=None):
     parser.add_argument("--sample_steps", type=int, default=40, help="Number of steps for generating sample images.")
     parser.add_argument("--sample_prompt", type=str, default=None, help="Prompt to use for sample image generation.")
    parser.add_argument("--sample_seed", type=int, default=-1, help="Seed for the per-checkpoint sample image generation. -1 selects a random seed.")
+    parser.add_argument("--convert_checkpoints", action="store_true", help="Auto-convert checkpoints to an inference-ready structure")
+    parser.add_argument("--multires", action="store_true", help="Disables dataset image transforms. Allows training on image datasets of arbitrary resolutions")
+
+    parser.add_argument("--device", type=str, default="cuda", help="Set the device to use for training (cuda, cuda:0, cuda:1, etc.).")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -345,6 +342,7 @@ def __init__(
         class_prompt=None,
         size=512,
         center_crop=False,
+        multires_enabled=False
     ):
         self.size = size
         self.center_crop = center_crop
@@ -355,6 +353,7 @@ def __init__(
             raise ValueError(f"Instance {self.instance_data_root} images root doesn't exist.")
 
         self.instance_images_path = list(Path(instance_data_root).iterdir())
+        self.instance_images_path = [path for path in self.instance_images_path if path.suffix in [".jpg", ".png", ".jpeg"]]
         self.num_instance_images = len(self.instance_images_path)
         self.instance_prompt = instance_prompt
         self._length = self.num_instance_images
@@ -369,14 +368,22 @@ def __init__(
         else:
             self.class_data_root = None
 
-        self.image_transforms = transforms.Compose(
-            [
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
-                transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
-            ]
-        )
+        if multires_enabled:  # skip resize/crop so images keep their native resolutions
+            self.image_transforms = transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Normalize([0.5], [0.5]),
+                ]
+            )
+        else:
+            self.image_transforms = transforms.Compose(
+                [
+                    transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                    transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                    transforms.ToTensor(),
+                    transforms.Normalize([0.5], [0.5]),
+                ]
+            )
 
     def __len__(self):
         return self._length
@@ -463,6 +470,10 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
+    gpu = int(args.device.split(":")[1]) if args.device != "cuda" else "cuda"
+    if torch.cuda.is_available():
+        torch.cuda.set_device(gpu)
+
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
@@ -470,6 +481,12 @@ def main(args):
         logging_dir=logging_dir,
     )
 
+    # The Accelerator device target is managed by an AcceleratorState object; grab it
+    # so we can directly set the device to run on
+    acc_state = accelerator.state
+    acc_state.device = torch.device(args.device)
+    accelerator.state = acc_state
+
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
     # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
@@ -645,6 +662,7 @@ def main(args):
         tokenizer=tokenizer,
         size=args.resolution,
         center_crop=args.center_crop,
+        multires_enabled=args.multires
     )
 
     train_dataloader = torch.utils.data.DataLoader(
@@ -840,8 +858,6 @@ def main(args):
                             # Make sure any data leftover from previous interim pipeline is cleared
                             if torch.cuda.is_available():
                                 torch.cuda.empty_cache()
-
-                            logger.info(f"Generating {args.samples_per_checkpoint} samples at step {global_step} with prompt: {args.sample_prompt}")
 
                             # Load current training state into a new diffusion pipeline to generate samples
                             pipeline = DiffusionPipeline.from_pretrained(
@@ -851,27 +867,37 @@ def main(args):
                                 torch_dtype=weight_dtype
                             ).to(accelerator.device)
 
+                            if args.convert_checkpoints:
+                                convertedOutPath = os.path.join(args.output_dir, "converted_checkpoints", f"checkpoint-{global_step}")
+                                logger.info(f"Converting checkpoint-{global_step} to a diffusers model at {convertedOutPath}")
+                                pipeline.save_pretrained(convertedOutPath)
+
+                            logger.info(f"Generating {args.samples_per_checkpoint} samples at step {global_step} with prompt: {args.sample_prompt}")
+
                             # Allow a statically set or random seed for generated samples
                             sampleSeed = args.sample_seed if args.sample_seed != -1 else torch.Generator(accelerator.device).seed()
                             sampleGenerator = torch.Generator(device=accelerator.device)
-                            sampleGenerator.manual_seed(sampleSeed)
-
-                            # Generate samples
-                            with torch.cuda.amp.autocast(enabled=True):
-                                images = pipeline(
-                                    prompt=args.sample_prompt,
-                                    num_images_per_prompt=args.samples_per_checkpoint,
-                                    num_inference_steps=args.sample_steps,
-                                    generator=sampleGenerator,
-                                    width=args.resolution,
-                                    height=args.resolution).images
-
-                            # Save samples to 'samples' folder in output directory
-                            for i, image in enumerate(images):
-                                hash_image = hashlib.sha1(image.tobytes()).hexdigest()
-                                os.makedirs(os.path.join(args.output_dir, "samples"), exist_ok=True)
-                                image_filename = os.path.join(args.output_dir, "samples", f"step-{global_step}_loss-{loss.detach().item()}_prompt-{args.sample_prompt}_hash-{hash_image}.png")
-                                image.save(image_filename)
+
+                            for sampleIdx in range(args.samples_per_checkpoint):
+                                sampleGenerator.manual_seed(sampleSeed)
+                                # Generate samples
+                                with torch.cuda.amp.autocast(enabled=True):
+                                    images = pipeline(
+                                        prompt=args.sample_prompt,
+                                        num_images_per_prompt=1,
+                                        num_inference_steps=args.sample_steps,
+                                        generator=sampleGenerator,
+                                        width=args.resolution,
+                                        height=args.resolution).images
+
+                                # Save samples to 'samples' folder in output directory
+                                for i, image in enumerate(images):
+                                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+                                    os.makedirs(os.path.join(args.output_dir, "samples"), exist_ok=True)
+                                    image_filename = os.path.join(args.output_dir, "samples", f"step-{global_step}_seed-{sampleSeed}_loss-{loss.detach().item()}_hash-{hash_image}.png")
+                                    image.save(image_filename)
+
+                                sampleSeed += 1
 
                             # Remove interim pipeline reference
                             del pipeline

From 4086ed0114c1cd52faae6f5639463fbd18295bd5 Mon Sep 17 00:00:00 2001
From: subpanic
Date: Wed, 18 Jan 2023 16:08:18 +0000
Subject: [PATCH 6/6] Additional cleanup.

---
 examples/dreambooth/convert_checkpoint.py | 42 -----------------------
 1 file changed, 42 deletions(-)
 delete mode 100644 examples/dreambooth/convert_checkpoint.py

diff --git a/examples/dreambooth/convert_checkpoint.py b/examples/dreambooth/convert_checkpoint.py
deleted file mode 100644
index 5e08e98564ac..000000000000
--- a/examples/dreambooth/convert_checkpoint.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from accelerate import Accelerator
-from diffusers import DiffusionPipeline
-import argparse
-import os
-import sys
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--checkpoint_path", type=str, required=True)
-parser.add_argument("--model_path", type=str, required=True)
-parser.add_argument("--output_path", type=str, required=True)
-args = parser.parse_args()
-
-
-# if os.path.exists(args.output_path):
-#     raise ValueError(f"Output path {args.output_path} already exists.")
-# sys.exit(1)
-
-os.makedirs(args.output_path, exist_ok=True)
-
-
-# Load the pipeline with the same arguments (model, revision) that were used for training
-model_id = args.model_path
-pipeline = DiffusionPipeline.from_pretrained(model_id)
-
-accelerator = Accelerator()
-
-# Use text_encoder if `--train_text_encoder` was used for the initial training
-unet, text_encoder = accelerator.prepare(pipeline.unet, pipeline.text_encoder)
-
-# Restore state from a checkpoint path. You have to use an absolute path here.
-accelerator.load_state(args.checkpoint_path)
-
-# Rebuild the pipeline with the unwrapped models (assignment to .unet and .text_encoder should work too)
-pipeline = DiffusionPipeline.from_pretrained(
-    model_id,
-    unet=accelerator.unwrap_model(unet),
-    text_encoder=accelerator.unwrap_model(text_encoder),
-)
-
-# Perform inference, or save, or push to the hub
-pipeline.save_pretrained(args.output_path)
\ No newline at end of file
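With --convert_checkpoints enabled, each checkpoint interval now writes an inference-ready model under converted_checkpoints/checkpoint-<step> in the output directory, which is what makes the standalone converter above redundant. A minimal loading sketch, assuming an output directory and checkpoint step from your own run:

import os
from diffusers import DiffusionPipeline

out_dir = "/path/to/output_dir"  # your --output_dir
step = 5000                      # a global step at which a checkpoint was converted
model_dir = os.path.join(out_dir, "converted_checkpoints", f"checkpoint-{step}")

pipeline = DiffusionPipeline.from_pretrained(model_dir).to("cuda")
image = pipeline("a woman sitting among trees against the horizon. drbolick style", num_inference_steps=40).images[0]
image.save("checkpoint_preview.png")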