From bdc05306f762b102bea7a15dd88e4d7c6d552708 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sun, 30 Apr 2023 09:24:31 -0400 Subject: [PATCH 1/2] Added input pretubation --- examples/text_to_image/train_text_to_image.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 1d62cb7f816d..96ed219e1f31 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -112,6 +112,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument("--input_pretubation", type=float, default=0, help="The scale of input pretubation. Recommended 0.1.") parser.add_argument( "--pretrained_model_name_or_path", type=str, @@ -801,7 +802,8 @@ def collate_fn(examples): noise += args.noise_offset * torch.randn( (latents.shape[0], latents.shape[1], 1, 1), device=latents.device ) - + if args.input_pretubation: + new_noise = noise + args.input_pretubation*torch.randn_like(noise) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) @@ -809,7 +811,10 @@ def collate_fn(examples): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + if args.input_pretubation: + noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) + else: + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"])[0] From c9542aa715afd56fa3fc678935c88fde7baafebb Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sun, 30 Apr 2023 09:29:47 -0400 Subject: [PATCH 2/2] Fixed spelling --- examples/text_to_image/train_text_to_image.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 96ed219e1f31..f9592e5adca3 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -112,7 +112,9 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument("--input_pretubation", type=float, default=0, help="The scale of input pretubation. Recommended 0.1.") + parser.add_argument( + "--input_pertubation", type=float, default=0, help="The scale of input pretubation. Recommended 0.1." + ) parser.add_argument( "--pretrained_model_name_or_path", type=str, @@ -802,8 +804,8 @@ def collate_fn(examples): noise += args.noise_offset * torch.randn( (latents.shape[0], latents.shape[1], 1, 1), device=latents.device ) - if args.input_pretubation: - new_noise = noise + args.input_pretubation*torch.randn_like(noise) + if args.input_pertubation: + new_noise = noise + args.input_pertubation * torch.randn_like(noise) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) @@ -811,7 +813,7 @@ def collate_fn(examples): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - if args.input_pretubation: + if args.input_pertubation: noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)