Dreambooth qol improvements (generating samples, device selection, multi res, checkpoint conversion) #2030
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
cc @patil-suraj |
Thanks a lot for the PR @subpanic, I left a few comments below.
Note that the goal of the example scripts is to have simple scripts that cover the most common use cases and are easy to read and modify if users want to change something. For additional features like multires we encourage users to take the script and adjust it on their own, so as to keep the example here simpler. Hope you understand.
> Support for optionally generating sample image outputs for each checkpoint.
That's a good point, @patrickvonplaten is working on adding logging to all training examples. @patrickvonplaten maybe we could add an argument like validation_steps so users could sync checkpointing and generating.
```python
raise ValueError(f"Instance {self.instance_data_root} images root doesn't exist.")
...
self.instance_images_path = list(Path(instance_data_root).iterdir())
self.instance_images_path = [path for path in self.instance_images_path if path.suffix in [".jpg", ".png", ".jpeg"]]
```
Good idea!
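As a side note, `Path.suffix` is case-sensitive, so the filter above would skip files like `photo.JPG`. A hypothetical case-insensitive variant (the helper name and `IMAGE_SUFFIXES` constant are illustrative, not from the PR) might look like:

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}

def list_instance_images(instance_data_root):
    # Match extensions case-insensitively so "photo.JPG" is not silently skipped.
    root = Path(instance_data_root)
    if not root.exists():
        raise ValueError(f"Instance images root {root} doesn't exist.")
    return sorted(p for p in root.iterdir() if p.suffix.lower() in IMAGE_SUFFIXES)
```

Sorting the result also makes the dataset ordering deterministic across runs, which `iterdir()` alone does not guarantee.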
```python
gpu = int(args.device.split(":")[1]) if args.device != "cuda" else "cuda"
if torch.cuda.is_available():
    torch.cuda.set_device(gpu)
```
I don't think we should have this here; it will break multi-GPU. We can set the device using the CUDA_VISIBLE_DEVICES env variable, which is also recommended by torch, cf. https://pytorch.org/docs/stable/generated/torch.cuda.set_device.html
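For illustration, the env-variable approach the comment suggests could look like the sketch below. The key point is that the variable must be set before CUDA is initialized (i.e. before `import torch`); the GPU index `"1"` here is just an example.

```python
import os

# Expose only physical GPU 1 to this process; frameworks then see it as
# "cuda:0". This must happen before `import torch` (or any CUDA init).
# Equivalent shell form:
#   CUDA_VISIBLE_DEVICES=1 accelerate launch train_dreambooth.py ...
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
```

With this approach the training script never needs a `--device` flag, and multi-GPU launches via `accelerate` keep working unchanged.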
```python
# Predict the noise residual
with torch.cuda.amp.autocast(enabled=True):
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
```
When we pass --mixed_precision="fp16" accelerate automatically uses autocast; we should not hardcode this here.
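To illustrate the intended division of labour, here is a non-runnable sketch (placeholder variables such as `unet`, `optimizer`, and `noisy_latents`; assumes the `accelerate` `Accelerator` API):

```python
from accelerate import Accelerator

# `accelerate launch --mixed_precision fp16 ...` configures the same thing
accelerator = Accelerator(mixed_precision="fp16")
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)

# The forward pass of a prepared model already runs under autocast when
# mixed precision is enabled, so no explicit torch.cuda.amp.autocast(...)
# context is needed around it.
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
```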
```python
# Also generate and save sample images if specified
if args.samples_per_checkpoint > 0:
    # Make sure any data leftover from previous interim pipeline is cleared
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Load current training state into a new diffusion pipeline to generate samples
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=accelerator.unwrap_model(unet),
        text_encoder=accelerator.unwrap_model(text_encoder),
        torch_dtype=weight_dtype,
    ).to(accelerator.device)

    if args.convert_checkpoints:
        convertedOutPath = os.path.join(args.output_dir, "converted_checkpoints", f"checkpoint-{global_step}")
        logger.info(f"Converting checkpoint-{global_step} to diffusers model at {convertedOutPath}")
        pipeline.save_pretrained(convertedOutPath)

    logger.info(f"Generating {args.samples_per_checkpoint} samples at step {global_step} with prompt: {args.sample_prompt}")

    # Allow a statically set or random seed for generated samples
    sampleSeed = args.sample_seed if args.sample_seed != -1 else torch.Generator(accelerator.device).seed()
    sampleGenerator = torch.Generator(device=accelerator.device)

    for sampleIdx in range(args.samples_per_checkpoint):
        sampleGenerator.manual_seed(sampleSeed)
        # Generate samples
        with torch.cuda.amp.autocast(enabled=True):
            images = pipeline(
                prompt=args.sample_prompt,
                num_images_per_prompt=1,
                num_inference_steps=args.sample_steps,
                generator=sampleGenerator,
                width=args.resolution,
                height=args.resolution,
            ).images

        # Save samples to 'samples' folder in output directory
        for i, image in enumerate(images):
            hash_image = hashlib.sha1(image.tobytes()).hexdigest()
            os.makedirs(os.path.join(args.output_dir, "samples"), exist_ok=True)
            image_filename = os.path.join(
                args.output_dir,
                "samples",
                f"step-{global_step}_seed-{sampleSeed}_loss-{loss.detach().item()}_hash-{hash_image}.png",
            )
            image.save(image_filename)

        sampleSeed += 1

    # Remove interim pipeline reference
    del pipeline
```
This can be simplified a lot; we have an example for logging, cf. https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py#L902
Hey @subpanic, could you maybe try to use the mechanism of https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py#L902 here for generation? It should be a bit easier than what we currently have :-)
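Loosely following that pattern, the generation block could shrink to something like the sketch below. Names such as `validation_prompt` and `num_validation_images` are assumptions about what the aligned interface might look like, not code from this PR, and the snippet is not runnable on its own:

```python
if args.validation_prompt is not None and global_step % args.checkpointing_steps == 0:
    # Build an inference pipeline from the current training state
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=accelerator.unwrap_model(unet),
        text_encoder=accelerator.unwrap_model(text_encoder),
        torch_dtype=weight_dtype,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
    images = [
        pipeline(args.validation_prompt, generator=generator).images[0]
        for _ in range(args.num_validation_images)
    ]

    # Log `images` with the configured tracker, then free the pipeline
    del pipeline
    torch.cuda.empty_cache()
```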
```python
    logging_dir=logging_dir,
)

# Accelerator device target is managed by an AcceleratorState object, grabbing
```
Let's not do this here. accelerate is responsible for device handling :-)
```python
    ]
)
if multires_enabled:
    self.image_transforms = transforms.Compose(
```
@patil-suraj what do you think regarding multires?
Not really in favour of this; I think it's better to keep the script simpler.
```python
parser.add_argument("--convert_checkpoints", action="store_true", help="Auto-convert checkpoints to an inference ready structure")
parser.add_argument("--multires", action="store_true", help="Disables dataset image transforms. Allows training on image datasets of arbitrary resolutions")

parser.add_argument("--device", type=str, default="cuda", help="Set the device to use for training (cuda, cuda:0, cuda:1, etc.).")
```
Let's not add this. accelerate should handle this :-)
```python
parser.add_argument("--sample_steps", type=int, default=40, help="Number of steps for generating sample images.")
parser.add_argument("--sample_prompt", type=str, default=None, help="Prompt to use for sample image generation.")
parser.add_argument("--sample_seed", type=int, default=-1, help="Seed for the per-checkpoint sample image generation. -1 to select random seed")
parser.add_argument("--convert_checkpoints", action="store_true", help="Auto-convert checkpoints to an inference ready structure")
```
Let's also not add this. We will handle this soon directly in accelerate as well, see: #2048
```python
    "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument("--samples_per_checkpoint", type=int, default=0, help="Whether or not to save samples for every checkpoint specified by --checkpointing_steps.")
parser.add_argument("--sample_steps", type=int, default=40, help="Number of steps for generating sample images.")
```
Can we instead maybe use:
```python
"--validation_epochs",
```
```python
parser.add_argument(
    "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument("--samples_per_checkpoint", type=int, default=0, help="Whether or not to save samples for every checkpoint specified by --checkpointing_steps.")
```
Hmm, can we maybe align this with:
```python
"--validation_epochs",
```
```python
)
parser.add_argument("--samples_per_checkpoint", type=int, default=0, help="Whether or not to save samples for every checkpoint specified by --checkpointing_steps.")
parser.add_argument("--sample_steps", type=int, default=40, help="Number of steps for generating sample images.")
parser.add_argument("--sample_prompt", type=str, default=None, help="Prompt to use for sample image generation.")
```
Could we call this:
```python
"--validation_prompt",
```
```python
parser.add_argument("--samples_per_checkpoint", type=int, default=0, help="Whether or not to save samples for every checkpoint specified by --checkpointing_steps.")
parser.add_argument("--sample_steps", type=int, default=40, help="Number of steps for generating sample images.")
parser.add_argument("--sample_prompt", type=str, default=None, help="Prompt to use for sample image generation.")
parser.add_argument("--sample_seed", type=int, default=-1, help="Seed for the per-checkpoint sample image generation. -1 to select random seed")
```
Suggested change:
```diff
-parser.add_argument("--sample_seed", type=int, default=-1, help="Seed for the per-checkpoint sample image generation. -1 to select random seed")
+parser.add_argument("--validation_seed", type=int, default=-1, help="Seed for the per-checkpoint sample image generation. -1 to select random seed")
```
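Whatever the flag ends up being called, the "-1 means random" convention could be factored into a small helper. This is just a sketch; `resolve_seed` is not part of the PR:

```python
import random

def resolve_seed(requested: int) -> int:
    """Return `requested` unchanged, or a fresh random 32-bit seed if it is -1."""
    if requested == -1:
        return random.SystemRandom().randrange(2**32)
    return requested
```

The resolved value can then be passed to `torch.Generator(...).manual_seed(...)` and recorded in the sample filename, so even "random" runs stay reproducible after the fact.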
```python
parser.add_argument("--sample_prompt", type=str, default=None, help="Prompt to use for sample image generation.")
parser.add_argument("--sample_seed", type=int, default=-1, help="Seed for the per-checkpoint sample image generation. -1 to select random seed")
parser.add_argument("--convert_checkpoints", action="store_true", help="Auto-convert checkpoints to an inference ready structure")
parser.add_argument("--multires", action="store_true", help="Disables dataset image transforms. Allows training on image datasets of arbitrary resolutions")
```
Looks good to me!
patrickvonplaten left a comment:
Thanks a lot for the PR. It would be great if we could align things here a bit with: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py#L159 if possible :-)
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Some additional features and quality of life improvements for the DreamBooth script. Opening a PR in case some or all of these are worth merging. I believe all additions should be non-destructive for existing script users, assuming autocast isn't a problem (but it might be, so it's worth investigating).
Some questions around areas I'm not 100% sure on, for the broader audience:
- To support device selection I call `set_device` for torch and directly access the `accelerator.state` to set the device. Unsure if there's a cleaner way to do this.
- The forward pass is wrapped in `with torch.cuda.amp.autocast(enabled=True):`. I have not noticed any performance or obvious training issues due to this, but if someone knows why this is a bad idea it would be greatly appreciated if they could help work out how to avoid the type mismatch issues without autocast.