diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 414ecdeb1fb7..3d2e694a1015 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -417,6 +417,16 @@ def parse_args(input_args=None): ), ) + parser.add_argument( + "--offset_noise", + action="store_true", + default=False, + help=( + "Fine-tuning against a modified noise" + " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information." + ), + ) + if input_args is not None: args = parser.parse_args(input_args) else: @@ -943,7 +953,12 @@ def load_model_hook(models, input_dir): latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) + if args.offset_noise: + noise = torch.randn_like(latents) + 0.1 * torch.randn( + latents.shape[0], latents.shape[1], 1, 1, device=latents.device + ) + else: + noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)