diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index aac720ffa474..0492da6da096 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -35,7 +35,6 @@
 
 logger = get_logger(__name__)
 
-
 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
     text_encoder_config = PretrainedConfig.from_pretrained(
         pretrained_model_name_or_path,
@@ -55,6 +54,12 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
     else:
         raise ValueError(f"{model_class} is not supported.")
 
+def reportVramUsage(note, device_id):
+    t = torch.cuda.get_device_properties(device_id).total_memory / 1073741824
+    r = torch.cuda.memory_reserved(device_id) / 1073741824
+    a = torch.cuda.memory_allocated(device_id) / 1073741824
+    f = r - a  # free inside reserved
+    print(f"({note}) Device {device_id} === Total: {t:.2f} GB, Reserved: {r:.2f} GB, Allocated: {a:.2f} GB, Free: {f:.2f} GB")
 
 def parse_args(input_args=None):
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
@@ -289,6 +294,14 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument("--samples_per_checkpoint", type=int, default=0, help="Number of sample images to generate for every checkpoint saved by --checkpointing_steps (0 disables sampling).")
+    parser.add_argument("--sample_steps", type=int, default=40, help="Number of inference steps for generating sample images.")
+    parser.add_argument("--sample_prompt", type=str, default=None, help="Prompt to use for sample image generation.")
+    parser.add_argument("--sample_seed", type=int, default=-1, help="Seed for the per-checkpoint sample image generation. -1 to select a random seed.")
+    parser.add_argument("--convert_checkpoints", action="store_true", help="Auto-convert checkpoints to an inference-ready structure.")
+    parser.add_argument("--multires", action="store_true", help="Disables dataset image transforms, allowing training on image datasets of arbitrary resolutions.")
+
+    parser.add_argument("--device", type=str, default="cuda", help="Set the device to use for training (cuda, cuda:0, cuda:1, etc.).")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -329,6 +342,7 @@ def __init__(
         class_prompt=None,
         size=512,
         center_crop=False,
+        multires_enabled=False,
     ):
         self.size = size
         self.center_crop = center_crop
@@ -339,6 +353,7 @@ def __init__(
             raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")
 
         self.instance_images_path = list(Path(instance_data_root).iterdir())
+        self.instance_images_path = [path for path in self.instance_images_path if path.suffix in [".jpg", ".png", ".jpeg"]]
         self.num_instance_images = len(self.instance_images_path)
         self.instance_prompt = instance_prompt
         self._length = self.num_instance_images
@@ -353,14 +368,22 @@ def __init__(
         else:
             self.class_data_root = None
 
-        self.image_transforms = transforms.Compose(
-            [
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
-                transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
-            ]
-        )
+        if multires_enabled:
+            self.image_transforms = transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Normalize([0.5], [0.5]),
+                ]
+            )
+        else:
+            self.image_transforms = transforms.Compose(
+                [
+                    transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                    transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                    transforms.ToTensor(),
+                    transforms.Normalize([0.5], [0.5]),
+                ]
+            )
 
     def __len__(self):
         return self._length
@@ -447,6 +470,10 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
 
+    gpu = int(args.device.split(":")[1]) if args.device != "cuda" else "cuda"
+    if torch.cuda.is_available():
+        torch.cuda.set_device(gpu)
+
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
@@ -454,6 +481,12 @@ def main(args):
         logging_dir=logging_dir,
     )
 
+    # The Accelerator's device target is managed by an AcceleratorState object; grab it
+    # so we can directly set the device to run on
+    acc_state = accelerator.state
+    acc_state.device = torch.device(args.device)
+    accelerator.state = acc_state
+
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
     # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
@@ -629,6 +662,7 @@ def main(args):
         tokenizer=tokenizer,
         size=args.resolution,
         center_crop=args.center_crop,
+        multires_enabled=args.multires,
     )
 
     train_dataloader = torch.utils.data.DataLoader(
@@ -768,7 +802,8 @@ def main(args):
                 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                 # Predict the noise residual
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                with torch.cuda.amp.autocast(enabled=True):
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
@@ -811,12 +846,63 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1
 
+                # Save checkpoint
                 if global_step % args.checkpointing_steps == 0:
                     if accelerator.is_main_process:
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
 
+                        # Also generate and save sample images if specified
+                        if args.samples_per_checkpoint > 0:
+                            # Make sure any data leftover from a previous interim pipeline is cleared
+                            if torch.cuda.is_available():
+                                torch.cuda.empty_cache()
+
+                            # Load the current training state into a new diffusion pipeline to generate samples
+                            pipeline = DiffusionPipeline.from_pretrained(
+                                args.pretrained_model_name_or_path,
+                                unet=accelerator.unwrap_model(unet),
+                                text_encoder=accelerator.unwrap_model(text_encoder),
+                                torch_dtype=weight_dtype,
+                            ).to(accelerator.device)
+
+                            if args.convert_checkpoints:
+                                convertedOutPath = os.path.join(args.output_dir, "converted_checkpoints", f"checkpoint-{global_step}")
+                                logger.info(f"Converting checkpoint-{global_step} to diffusers model at {convertedOutPath}")
+                                pipeline.save_pretrained(convertedOutPath)
+
+                            logger.info(f"Generating {args.samples_per_checkpoint} samples at step {global_step} with prompt: {args.sample_prompt}")
+
+                            # Allow a statically set or random seed for generated samples
+                            sampleSeed = args.sample_seed if args.sample_seed != -1 else torch.Generator(accelerator.device).seed()
+                            sampleGenerator = torch.Generator(device=accelerator.device)
+
+                            for sampleIdx in range(args.samples_per_checkpoint):
+                                sampleGenerator.manual_seed(sampleSeed)
+                                # Generate samples
+                                with torch.cuda.amp.autocast(enabled=True):
+                                    images = pipeline(
+                                        prompt=args.sample_prompt,
+                                        num_images_per_prompt=1,
+                                        num_inference_steps=args.sample_steps,
+                                        generator=sampleGenerator,
+                                        width=args.resolution,
+                                        height=args.resolution).images
+
+                                # Save samples to the 'samples' folder in the output directory
+                                for i, image in enumerate(images):
+                                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+                                    os.makedirs(os.path.join(args.output_dir, "samples"), exist_ok=True)
+                                    image_filename = os.path.join(args.output_dir, "samples", f"step-{global_step}_seed-{sampleSeed}_loss-{loss.detach().item()}_hash-{hash_image}.png")
+                                    image.save(image_filename)
+
+                                sampleSeed += 1
+
+                            # Remove interim pipeline reference
+                            del pipeline
+
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
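For reference, a minimal sketch of how the new options could be exercised end to end. It assumes the patched script is importable as train_dreambooth and drives it through the existing parse_args(input_args)/main() entry points; the model name, paths, and hyperparameter values are placeholders, not part of the patch.

    # Sketch: run the patched script programmatically with the new per-checkpoint sampling flags.
    # All values below are placeholders; adjust model, paths, and hyperparameters to your setup.
    import train_dreambooth as td

    args = td.parse_args([
        "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",
        "--instance_data_dir", "./instance_images",
        "--instance_prompt", "a photo of sks dog",
        "--output_dir", "./dreambooth-model",
        "--resolution", "512",
        "--train_batch_size", "1",
        "--learning_rate", "5e-6",
        "--max_train_steps", "1200",
        "--checkpointing_steps", "200",
        # Flags added by this patch:
        "--samples_per_checkpoint", "2",   # render 2 samples at every checkpoint
        "--sample_prompt", "a photo of sks dog on a beach",
        "--sample_steps", "40",
        "--sample_seed", "1337",           # fixed seed, so samples are comparable across checkpoints
        "--convert_checkpoints",           # also save an inference-ready pipeline per checkpoint
        "--device", "cuda:0",
    ])
    td.main(args)

With these settings, every 200 steps the script saves the training state, writes an inference-ready pipeline copy under output_dir/converted_checkpoints/, and renders two fixed-seed samples (seed incremented per image) into output_dir/samples/.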