106 changes: 96 additions & 10 deletions examples/dreambooth/train_dreambooth.py
@@ -35,7 +35,6 @@

logger = get_logger(__name__)


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
@@ -55,6 +54,12 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
else:
raise ValueError(f"{model_class} is not supported.")

def reportVramUsage(note, device_id):
    # Report CUDA memory stats in GiB (1024**3 bytes = 1073741824).
    t = torch.cuda.get_device_properties(device_id).total_memory / 1024**3
    r = torch.cuda.memory_reserved(device_id) / 1024**3   # held by the caching allocator
    a = torch.cuda.memory_allocated(device_id) / 1024**3  # actually in use by tensors
    f = r - a  # free inside the reserved pool
    print(f"({note}) Device {device_id} === Total: {t:.2f} GB, Reserved: {r:.2f} GB, Allocated: {a:.2f} GB, Free: {f:.2f} GB")
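For context, a hedged example of how this helper might be wrapped around a large allocation; the call sites below are hypothetical and not part of this diff:

# Hypothetical call sites for the VRAM helper (not in the original diff):
reportVramUsage("before unet load", 0)
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
reportVramUsage("after unet load", 0)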

def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
@@ -289,6 +294,14 @@ def parse_args(input_args=None):
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument("--samples_per_checkpoint", type=int, default=0, help="Number of sample images to generate for every checkpoint saved by --checkpointing_steps. 0 disables sampling.")
Contributor: Hmm, can we maybe align this with … and not have this argument?

parser.add_argument("--sample_steps", type=int, default=40, help="Number of steps for generating sample images.")
Contributor: Can we instead maybe use … ?

parser.add_argument("--sample_prompt", type=str, default=None, help="Prompt to use for sample image generation.")
Contributor: Could we call this … maybe?

parser.add_argument("--sample_seed", type=int, default=-1, help="Seed for the per-checkpoint sample image generation. -1 to select random seed")
Contributor suggested change:
-    parser.add_argument("--sample_seed", type=int, default=-1, help="Seed for the per-checkpoint sample image generation. -1 to select random seed")
+    parser.add_argument("--validation_seed", type=int, default=-1, help="Seed for the per-checkpoint sample image generation. -1 to select random seed")

parser.add_argument("--convert_checkpoints", action="store_true", help="Auto-convert checkpoints to an inference-ready structure")
Contributor: Let's also not add this. We will handle this soon directly in accelerate as well, see: #2048

parser.add_argument("--multires", action="store_true", help="Disables the resize/crop dataset image transforms, allowing training on image datasets of arbitrary resolutions")
Contributor: Looks good to me!


parser.add_argument("--device", type=str, default="cuda", help="Set the device to use for training (cuda, cuda:0, cuda:1, etc.).")
Contributor: Let's not add this. accelerate should handle this :-)


if input_args is not None:
args = parser.parse_args(input_args)
@@ -329,6 +342,7 @@ def __init__(
class_prompt=None,
size=512,
center_crop=False,
multires_enabled=False
):
self.size = size
self.center_crop = center_crop
@@ -339,6 +353,7 @@ def __init__(
raise ValueError(f"Instance {self.instance_data_root} images root doesn't exist.")

self.instance_images_path = list(Path(instance_data_root).iterdir())
self.instance_images_path = [path for path in self.instance_images_path if path.suffix in [".jpg", ".png", ".jpeg"]]
Contributor: Good idea!

self.num_instance_images = len(self.instance_images_path)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
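One caveat about the extension filter above (an observation, not something raised in the review): Path.suffix is case-sensitive, so files named *.JPG would be skipped. A hedged sketch of a case-insensitive variant:

# Assumption: upper-case extensions should be accepted too.
self.instance_images_path = [
    path for path in self.instance_images_path
    if path.suffix.lower() in {".jpg", ".jpeg", ".png"}
]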
@@ -353,14 +368,22 @@ def __init__(
else:
self.class_data_root = None

self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
if multires_enabled:
    self.image_transforms = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
else:
    self.image_transforms = transforms.Compose(
        [
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

Contributor: @patil-suraj, what do you think regarding multires?

Contributor: Not really in favour of this. I think it's better to keep the script simpler.

def __len__(self):
return self._length
@@ -447,13 +470,23 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)

# Optionally pin training to a specific GPU, e.g. --device cuda:1
if torch.cuda.is_available() and ":" in args.device:
    torch.cuda.set_device(int(args.device.split(":")[1]))
Contributor (on lines +473 to +475): I don't think we should have this here; it will break multi-GPU. We can set the device using the CUDA_VISIBLE_DEVICES env variable, which is also recommended by torch, cf. https://pytorch.org/docs/stable/generated/torch.cuda.set_device.html

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
logging_dir=logging_dir,
)

# Accelerator device targeting is managed by an AcceleratorState object; grab it
# so we can directly set the device to run on.
acc_state = accelerator.state
acc_state.device = torch.device(args.device)
accelerator.state = acc_state

Contributor: Let's not do this here. accelerate is responsible for device handling :-)

# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
@@ -629,6 +662,7 @@ def main(args):
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
multires_enabled=args.multires
)

train_dataloader = torch.utils.data.DataLoader(
@@ -768,7 +802,8 @@ def main(args):
encoder_hidden_states = text_encoder(batch["input_ids"])[0]

# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
with torch.cuda.amp.autocast(enabled=True):
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

Contributor: When we pass --mixed_precision="fp16", accelerate automatically uses autocast; we should not hardcode this here.
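If the hardcoded autocast were dropped, an accelerate-native sketch could look like the following; this assumes the reviewer's point that --mixed_precision already drives autocast:

# Sketch: let accelerate decide whether (and how) to autocast, based on the
# --mixed_precision flag passed at launch.
with accelerator.autocast():
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample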

# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
@@ -811,12 +846,63 @@ def main(args):
progress_bar.update(1)
global_step += 1

# Save checkpoint
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")

# Also generate and save sample images if specified
if args.samples_per_checkpoint > 0:
# Make sure any data leftover from previous interim pipeline is cleared
if torch.cuda.is_available():
torch.cuda.empty_cache()

# Load current training state into a new diffusion pipeline to generate samples
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
torch_dtype=weight_dtype
).to(accelerator.device)

if args.convert_checkpoints:
convertedOutPath = os.path.join(args.output_dir, "converted_checkpoints", f"checkpoint-{global_step}")
logger.info(f"Converting checkpoint-{global_step} to diffusers model at {convertedOutPath}")
pipeline.save_pretrained(convertedOutPath)

logger.info(f"Generating {args.samples_per_checkpoint} samples at step {global_step} with prompt: {args.sample_prompt}")

# Allow a statically set or random seed for generated samples
sampleSeed = args.sample_seed if args.sample_seed != -1 else torch.Generator(accelerator.device).seed()
sampleGenerator = torch.Generator(device=accelerator.device)

for sampleIdx in range(args.samples_per_checkpoint):
sampleGenerator.manual_seed(sampleSeed)
# Generate samples
with torch.cuda.amp.autocast(enabled=True):
images = pipeline(
prompt=args.sample_prompt,
num_images_per_prompt=1,
num_inference_steps=args.sample_steps,
generator=sampleGenerator,
width=args.resolution,
height=args.resolution).images

# Save samples to 'samples' folder in output directory
for i, image in enumerate(images):
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
os.makedirs(os.path.join(args.output_dir, "samples"), exist_ok=True)
image_filename = os.path.join(args.output_dir, "samples", f"step-{global_step}_seed-{sampleSeed}_loss-{loss.detach().item():.4f}_hash-{hash_image}.png")
image.save(image_filename)

sampleSeed += 1

# Remove interim pipeline reference
del pipeline
Contributor (on lines +856 to +903): Hey @subpanic, could you maybe try to use the mechanism of https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py#L902 here for generation? It should be a bit easier than what we currently have :-)

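For reference, a hedged sketch of the generation pattern that train_dreambooth_lora.py uses, adapted to this PR's argument names (sample_prompt, sample_steps, samples_per_checkpoint, sample_seed); treat it as a sketch, not the final mechanism:

# Sketch: build the pipeline once, sample in a loop, then free it (assumptions above).
pipeline = DiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    unet=accelerator.unwrap_model(unet),
    text_encoder=accelerator.unwrap_model(text_encoder),
    torch_dtype=weight_dtype,
)
pipeline.set_progress_bar_config(disable=True)
pipeline = pipeline.to(accelerator.device)
generator = (
    torch.Generator(device=accelerator.device).manual_seed(args.sample_seed)
    if args.sample_seed != -1
    else None
)
images = [
    pipeline(args.sample_prompt, num_inference_steps=args.sample_steps, generator=generator).images[0]
    for _ in range(args.samples_per_checkpoint)
]
del pipeline
torch.cuda.empty_cache()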

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)