Dreambooth QoL improvements (generating samples, device selection, multi res, checkpoint conversion) #2030

Changes from all commits: aee5096, 223fbc7, 48dc919, 2d04837, 8540e2b, b684bbb, 519aa0a, 8336e20, 4086ed0
```diff
@@ -35,7 +35,6 @@
 logger = get_logger(__name__)


 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
     text_encoder_config = PretrainedConfig.from_pretrained(
         pretrained_model_name_or_path,
```
```diff
@@ -55,6 +54,12 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str
     else:
         raise ValueError(f"{model_class} is not supported.")


+def reportVramUsage(note, device_id):
+    t = torch.cuda.get_device_properties(device_id).total_memory / 1073741824
+    r = torch.cuda.memory_reserved(device_id) / 1073741824
+    a = torch.cuda.memory_allocated(device_id) / 1073741824
+    f = r - a  # free inside reserved
+    print(f"({note}) Device {device_id} === Total: {t:.2f} GB, Reserved: {r:.2f} GB, Allocated: {a:.2f} GB, Free: {f:.2f} GB")
```
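(The magic constant 1073741824 is 1024**3, so the figures are GiB.) A quick sketch of how the helper might be exercised, assuming the `reportVramUsage` definition above is in scope; the allocation size is arbitrary:

```python
import torch

# Bracket an allocation with the helper to watch reserved/allocated change.
if torch.cuda.is_available():
    reportVramUsage("before allocation", 0)
    x = torch.zeros(1024, 1024, 256, device="cuda:0")  # ~1 GiB of float32
    reportVramUsage("after allocation", 0)
```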
```diff
 def parse_args(input_args=None):
     parser = argparse.ArgumentParser(description="Simple example of a training script.")

@@ -289,6 +294,14 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument("--samples_per_checkpoint", type=int, default=0, help="Whether or not to save samples for every checkpoint specified by --checkpointing_steps.")
+    parser.add_argument("--sample_steps", type=int, default=40, help="Number of steps for generating sample images.")
```
**Contributor:** Can we instead maybe use:
| parser.add_argument("--sample_prompt", type=str, default=None, help="Prompt to use for sample image generation.") | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we call this:
|
||||||
| parser.add_argument("--sample_seed", type=int, default=-1, help="Seed for the per-checkpoint sample image generation. -1 to select random seed") | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| parser.add_argument("--convert_checkpoints", action="store_true", help="Auto-convert checkpoints to an inference ready structure") | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's also not add this. We will handle this soon directly in accelerate as well, see: #2048 |
||||||
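Since the PR writes converted checkpoints with `pipeline.save_pretrained`, such a directory is a complete diffusers pipeline and can be loaded straight back for inference; a minimal sketch, with a hypothetical output path and prompt:

```python
from diffusers import DiffusionPipeline

# Hypothetical path produced by --convert_checkpoints at step 500.
pipe = DiffusionPipeline.from_pretrained("dreambooth-out/converted_checkpoints/checkpoint-500")
pipe = pipe.to("cuda")

image = pipe("a photo of sks dog", num_inference_steps=30).images[0]
image.save("sample.png")
```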
| parser.add_argument("--multires", action="store_true", help="Disables dataset image transforms. Allows training on image datasets of arbitrary resolutions") | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks good to me! |
||||||
```diff
+    parser.add_argument("--device", type=str, default="cuda", help="Set the device to use for training (cuda, cuda:0, cuda:1, etc.).")
```

**Contributor:** Let's not add this.
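Taken together, a hypothetical invocation exercising the new flags could look like the sketch below (model name, paths, and prompts are placeholders; the script's usual required Dreambooth arguments still apply):

```python
# Placeholder values throughout; not a recommended configuration.
args = parse_args([
    "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",
    "--instance_data_dir", "./instance_images",
    "--instance_prompt", "a photo of sks dog",
    "--output_dir", "./dreambooth-out",
    "--checkpointing_steps", "500",
    "--samples_per_checkpoint", "4",
    "--sample_prompt", "a photo of sks dog on the beach",
    "--sample_steps", "40",
    "--sample_seed", "1234",
    "--convert_checkpoints",
    "--multires",
])
```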
```diff
     if input_args is not None:
         args = parser.parse_args(input_args)

@@ -329,6 +342,7 @@ def __init__(
         class_prompt=None,
         size=512,
         center_crop=False,
+        multires_enabled=False
     ):
         self.size = size
         self.center_crop = center_crop
```
```diff
@@ -339,6 +353,7 @@ def __init__(
             raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")

         self.instance_images_path = list(Path(instance_data_root).iterdir())
+        self.instance_images_path = [path for path in self.instance_images_path if path.suffix in [".jpg", ".png", ".jpeg"]]
```

**Contributor:** Good idea!
```diff
         self.num_instance_images = len(self.instance_images_path)
         self.instance_prompt = instance_prompt
         self._length = self.num_instance_images
```
```diff
@@ -353,14 +368,22 @@ def __init__(
         else:
             self.class_data_root = None

-        self.image_transforms = transforms.Compose(
-            [
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
-                transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
-            ]
-        )
+        if multires_enabled:
+            self.image_transforms = transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Normalize([0.5], [0.5]),
+                ]
+            )
+        else:
+            self.image_transforms = transforms.Compose(
+                [
+                    transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                    transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                    transforms.ToTensor(),
+                    transforms.Normalize([0.5], [0.5]),
+                ]
+            )

     def __len__(self):
         return self._length
```

**Contributor:** @patil-suraj what do you think regarding multires?

**Contributor:** Not really in favour of this. I think better to keep the script simpler.
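A batching caveat worth noting for `--multires` (my observation, not something stated in the PR): with Resize/RandomCrop removed, samples in a batch can have different tensor shapes, so the `torch.stack` in the collate step fails for any `train_batch_size` above 1. A self-contained sketch of the failure mode and the safe setting:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class VariableSizeImages(Dataset):
    """Stand-in for the multires dataset: each item keeps its own resolution."""

    def __init__(self):
        self.shapes = [(3, 512, 512), (3, 640, 448), (3, 768, 512)]

    def __len__(self):
        return len(self.shapes)

    def __getitem__(self, idx):
        return torch.zeros(self.shapes[idx])


# batch_size=1 always works; batch_size=2 would raise in the default
# collate_fn, because torch.stack requires identical shapes.
loader = DataLoader(VariableSizeImages(), batch_size=1, shuffle=True)
for pixel_values in loader:
    print(pixel_values.shape)  # e.g. torch.Size([1, 3, 640, 448])
```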
```diff
@@ -447,13 +470,23 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)

+    gpu = int(args.device.split(":")[1]) if args.device != "cuda" else "cuda"
+    if torch.cuda.is_available():
+        torch.cuda.set_device(gpu)
```

**Contributor** (on lines +473 to +475): don't think we should have this here, this will break multi-gpu. We can set the device using
```diff
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         logging_dir=logging_dir,
     )

+    # Accelerator device target is managed by an AcceleratorState object, grabbing
+    # so we can directly set the device to run on
+    acc_state = accelerator.state
+    acc_state.device = torch.device(args.device)
+    accelerator.state = acc_state
```

**Contributor:** Let's not do this here.
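The truncated suggestion above presumably refers to environment-based selection. A sketch of that alternative: restrict the visible GPUs before CUDA initializes, instead of patching `accelerator.state` after construction (the device index is an arbitrary example):

```python
import os

# Must be set before torch/accelerate initialize CUDA; afterwards the chosen
# physical GPU appears to the process as cuda:0.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # example: physical GPU 1

from accelerate import Accelerator

accelerator = Accelerator()
print(accelerator.device)  # cuda:0 (mapped to physical GPU 1), or cpu if no GPU
```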
```diff
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
     # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
```
```diff
@@ -629,6 +662,7 @@ def main(args):
         tokenizer=tokenizer,
         size=args.resolution,
         center_crop=args.center_crop,
+        multires_enabled=args.multires
     )

     train_dataloader = torch.utils.data.DataLoader(
```
```diff
@@ -768,7 +802,8 @@ def main(args):
                 encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                 # Predict the noise residual
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                with torch.cuda.amp.autocast(enabled=True):
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
```

**Contributor:** When we pass
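The comment above is cut off, but it presumably concerns the `mixed_precision` setting already passed to `Accelerator`. Under that assumption, a sketch of the more idiomatic form delegates autocasting to accelerate rather than hard-coding `torch.cuda.amp.autocast`; `accelerator.autocast()` is an existing accelerate context manager, but using it here is my suggestion, not the PR's (`unet`, `noisy_latents`, etc. come from the training loop above):

```python
# Accelerator(mixed_precision=...) already knows the autocast policy, so the
# forward pass can rely on its context manager instead of a hard-coded
# torch.cuda.amp.autocast(enabled=True).
with accelerator.autocast():
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
```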
```diff
@@ -811,12 +846,63 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1

                 # Save checkpoint
                 if global_step % args.checkpointing_steps == 0:
                     if accelerator.is_main_process:
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")

+                        # Also generate and save sample images if specified
+                        if args.samples_per_checkpoint > 0:
+                            # Make sure any data leftover from previous interim pipeline is cleared
+                            if torch.cuda.is_available():
+                                torch.cuda.empty_cache()
+
+                            # Load current training state into a new diffusion pipeline to generate samples
+                            pipeline = DiffusionPipeline.from_pretrained(
+                                args.pretrained_model_name_or_path,
+                                unet=accelerator.unwrap_model(unet),
+                                text_encoder=accelerator.unwrap_model(text_encoder),
+                                torch_dtype=weight_dtype
+                            ).to(accelerator.device)
+
+                            if args.convert_checkpoints:
+                                convertedOutPath = os.path.join(args.output_dir, "converted_checkpoints", f"checkpoint-{global_step}")
+                                logger.info(f"Converting checkpoint-{global_step} to diffusers model at {convertedOutPath}")
+                                pipeline.save_pretrained(convertedOutPath)
+
+                            logger.info(f"Generating {args.samples_per_checkpoint} samples at step {global_step} with prompt: {args.sample_prompt}")
+
+                            # Allow a statically set or random seed for generated samples
+                            sampleSeed = args.sample_seed if args.sample_seed != -1 else torch.Generator(accelerator.device).seed()
+                            sampleGenerator = torch.Generator(device=accelerator.device)
+
+                            for sampleIdx in range(args.samples_per_checkpoint):
+                                sampleGenerator.manual_seed(sampleSeed)
+                                # Generate samples
+                                with torch.cuda.amp.autocast(enabled=True):
+                                    images = pipeline(
+                                        prompt=args.sample_prompt,
+                                        num_images_per_prompt=1,
+                                        num_inference_steps=args.sample_steps,
+                                        generator=sampleGenerator,
+                                        width=args.resolution,
+                                        height=args.resolution).images
+
+                                # Save samples to 'samples' folder in output directory
+                                for i, image in enumerate(images):
+                                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+                                    os.makedirs(os.path.join(args.output_dir, "samples"), exist_ok=True)
+                                    image_filename = os.path.join(args.output_dir, "samples", f"step-{global_step}_seed-{sampleSeed}_loss-{loss.detach().item()}_hash-{hash_image}.png")
+                                    image.save(image_filename)
+
+                                sampleSeed += 1
+
+                            # Remove interim pipeline reference
+                            del pipeline
```
**Contributor** (on lines +856 to +903): This can be simplified a lot, we have an example for logging, cf. https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py#L902

**Contributor:** Hey @subpanic, could you maybe try to use the mechanism of https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py#L902 here for generation? It should be a bit easier than what we currently have :-)

```diff
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
```
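For reference, a minimal sketch of the mechanism those comments point to, modeled on the LoRA script's validation loop; `args.validation_prompt` and `args.num_validation_images` follow that script's naming and are assumptions here, and the surrounding names (`unet`, `text_encoder`, `weight_dtype`) come from the training script above:

```python
# Build an interim pipeline from the current training state, render a few
# validation images with a fixed-seed generator, then free the pipeline.
pipeline = DiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    unet=accelerator.unwrap_model(unet),
    text_encoder=accelerator.unwrap_model(text_encoder),
    torch_dtype=weight_dtype,
).to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)

generator = (
    torch.Generator(device=accelerator.device).manual_seed(args.seed)
    if args.seed is not None
    else None
)
images = [
    pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
    for _ in range(args.num_validation_images)
]

del pipeline
torch.cuda.empty_cache()
```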
**Contributor:** Hmm can we maybe align this with `diffusers/examples/dreambooth/train_dreambooth_lora.py`, line 171 in d75ad93?